aidial_adapter_openai/gpt4_multi_modal/chat_completion.py
import os
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Tuple,
TypeVar,
cast,
)

import aiohttp
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, Response

from aidial_adapter_openai.dial_api.storage import FileStorage
from aidial_adapter_openai.gpt4_multi_modal.gpt4_vision import (
convert_gpt4v_to_gpt4_chunk,
)
from aidial_adapter_openai.gpt4_multi_modal.transformation import (
SUPPORTED_FILE_EXTS,
ResourceProcessor,
)
from aidial_adapter_openai.utils.auth import OpenAICreds, get_auth_headers
from aidial_adapter_openai.utils.chat_completion_response import (
ChatCompletionBlock,
)
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.multi_modal_message import MultiModalMessage
from aidial_adapter_openai.utils.sse_stream import parse_openai_sse_stream
from aidial_adapter_openai.utils.streaming import (
create_response_from_chunk,
create_stage_chunk,
generate_stream,
map_stream,
prepend_to_stream,
)
from aidial_adapter_openai.utils.tokenizer import MultiModalTokenizer
from aidial_adapter_openai.utils.truncate_prompt import (
DiscardedMessages,
TruncatedTokens,
truncate_prompt,
)

# The built-in default max_tokens is 16 tokens,
# which is too small for most image-to-text use cases.
GPT4V_DEFAULT_MAX_TOKENS = int(os.getenv("GPT4_VISION_MAX_TOKENS", "1024"))
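# Guidance shown to the user in a "Usage" stage when the request messages
# cannot be transformed (e.g. due to unsupported attachments).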
USAGE = f"""
### Usage
The application answers queries about attached images.
Attach images and ask questions about them.
Supported image types: {', '.join(SUPPORTED_FILE_EXTS)}.
Examples of queries:
- "Describe this picture" for one image,
- "What are in these images? Is there any difference between them?" for multiple images.
""".strip()
async def transpose_stream(
stream: AsyncIterator[bytes | Response],
) -> AsyncIterator[bytes] | Response:
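    """
    Peek at the first item of the stream: if the upstream call failed and
    produced a `Response`, return it directly; otherwise return the byte
    stream with the consumed first chunk prepended back to it.
    """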
first_chunk: Optional[bytes] = None
async for chunk in stream:
if isinstance(chunk, Response):
return chunk
else:
first_chunk = chunk
break
stream = cast(AsyncIterator[bytes], stream)
if first_chunk is not None:
stream = prepend_to_stream(first_chunk, stream)
    return stream


async def predict_stream(
api_url: str, headers: Dict[str, str], request: Any
) -> AsyncIterator[bytes] | Response:
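    """
    Call the upstream endpoint in streaming mode and return either an SSE
    byte stream or an error `Response`.
    """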
    return await transpose_stream(predict_stream_raw(api_url, headers, request))


async def predict_stream_raw(
api_url: str, headers: Dict[str, str], request: Any
) -> AsyncIterator[bytes | Response]:
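    """
    Yield raw SSE lines from the upstream endpoint. A non-200 status is
    surfaced as a single `JSONResponse` item wrapping the upstream error body.
    """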
async with aiohttp.ClientSession() as session:
async with session.post(
api_url, json=request, headers=headers
) as response:
if response.status != 200:
yield JSONResponse(
status_code=response.status, content=await response.json()
)
return
async for line in response.content:
                yield line


async def predict_non_stream(
api_url: str, headers: Dict[str, str], request: Any
) -> dict | JSONResponse:
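    """
    Call the upstream endpoint in non-streaming mode and return the parsed
    JSON body, or a `JSONResponse` wrapping the upstream error.
    """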
async with aiohttp.ClientSession() as session:
async with session.post(
api_url, json=request, headers=headers
) as response:
if response.status != 200:
return JSONResponse(
status_code=response.status, content=await response.json()
)
            return await response.json()


def multi_modal_truncate_prompt(
request: dict,
messages: List[MultiModalMessage],
max_prompt_tokens: int,
tokenizer: MultiModalTokenizer,
) -> Tuple[List[MultiModalMessage], DiscardedMessages, TruncatedTokens]:
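    """
    Truncate the multi-modal prompt so that it fits into `max_prompt_tokens`,
    keeping system messages intact. Returns the remaining messages, the
    discarded messages and the estimated prompt token count after truncation.
    """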
return truncate_prompt(
messages=messages,
message_tokens=tokenizer.tokenize_request_message,
is_system_message=lambda message: message.raw_message["role"]
== "system",
max_prompt_tokens=max_prompt_tokens,
initial_prompt_tokens=tokenizer.tokenize_request(request, []),
    )


async def gpt4o_chat_completion(
request: Any,
deployment: str,
upstream_endpoint: str,
creds: OpenAICreds,
is_stream: bool,
file_storage: Optional[FileStorage],
api_version: str,
tokenizer: MultiModalTokenizer,
eliminate_empty_choices: bool,
):
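    """
    Chat completion for GPT-4o-like multi-modal deployments: response chunks
    are passed through unchanged and no default `max_tokens` is imposed.
    """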
return await chat_completion(
request,
deployment,
upstream_endpoint,
creds,
is_stream,
file_storage,
api_version,
tokenizer,
lambda x: x,
None,
eliminate_empty_choices,
    )


async def gpt4_vision_chat_completion(
request: Any,
deployment: str,
upstream_endpoint: str,
creds: OpenAICreds,
is_stream: bool,
file_storage: Optional[FileStorage],
api_version: str,
tokenizer: MultiModalTokenizer,
eliminate_empty_choices: bool,
):
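    """
    Chat completion for GPT-4 Vision deployments: response chunks are
    converted from the GPT-4V format to the regular GPT-4 format and
    `max_tokens` defaults to `GPT4V_DEFAULT_MAX_TOKENS` when not provided.
    """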
return await chat_completion(
request,
deployment,
upstream_endpoint,
creds,
is_stream,
file_storage,
api_version,
tokenizer,
convert_gpt4v_to_gpt4_chunk,
GPT4V_DEFAULT_MAX_TOKENS,
eliminate_empty_choices,
    )


async def chat_completion(
request: Any,
deployment: str,
upstream_endpoint: str,
creds: OpenAICreds,
is_stream: bool,
file_storage: Optional[FileStorage],
api_version: str,
tokenizer: MultiModalTokenizer,
response_transformer: Callable[[dict], dict | None],
default_max_tokens: Optional[int],
eliminate_empty_choices: bool,
):
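    """
    Shared multi-modal chat completion flow:
    1. validate the request (single choice, non-empty messages),
    2. transform the messages and their attachments into multi-modal messages,
    3. optionally truncate the prompt to `max_prompt_tokens`,
    4. proxy the request to the upstream endpoint in streaming or
       non-streaming mode, attaching token usage and discarded messages.

    If the messages cannot be transformed, a "Usage" stage with attachment
    guidance is returned along with the error.
    """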
if request.get("n", 1) > 1:
raise RequestValidationError("The deployment doesn't support n > 1")
messages: List[Any] = request["messages"]
if len(messages) == 0:
raise RequestValidationError("The request doesn't contain any messages")
api_url = f"{upstream_endpoint}?api-version={api_version}"
transform_result = await ResourceProcessor(
file_storage=file_storage
).transform_messages(messages)
if isinstance(transform_result, DialException):
logger.error(f"Failed to prepare request: {transform_result.message}")
chunk = create_stage_chunk("Usage", USAGE, is_stream)
return create_response_from_chunk(chunk, transform_result, is_stream)
multi_modal_messages = transform_result
discarded_messages = None
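    # Truncate the prompt only if the client explicitly requested it via
    # max_prompt_tokens; otherwise just estimate the prompt size for later
    # usage reporting.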
max_prompt_tokens = request.pop("max_prompt_tokens", None)
if max_prompt_tokens is not None:
multi_modal_messages, discarded_messages, estimated_prompt_tokens = (
multi_modal_truncate_prompt(
request=request,
messages=multi_modal_messages,
max_prompt_tokens=max_prompt_tokens,
tokenizer=tokenizer,
)
)
logger.debug(
f"prompt tokens after truncation: {estimated_prompt_tokens}"
)
else:
estimated_prompt_tokens = tokenizer.tokenize_request(
request, multi_modal_messages
)
logger.debug(
f"prompt tokens without truncation: {estimated_prompt_tokens}"
)
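    # Rebuild the upstream request: substitute the transformed messages and
    # fall back to the default max_tokens when the client didn't set one.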
request = {
**request,
"max_tokens": request.get("max_tokens") or default_max_tokens,
"messages": [m.raw_message for m in multi_modal_messages],
}
headers = get_auth_headers(creds)
if is_stream:
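        # Proxy the upstream SSE stream, transforming each chunk and letting
        # generate_stream compute usage and attach the discarded messages.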
response = await predict_stream(api_url, headers, request)
if isinstance(response, Response):
return response
T = TypeVar("T")
def debug_print(chunk: T) -> T:
logger.debug(f"chunk: {chunk}")
return chunk
return map_stream(
debug_print,
generate_stream(
stream=map_stream(
response_transformer,
parse_openai_sse_stream(response),
),
get_prompt_tokens=lambda: estimated_prompt_tokens,
tokenize_response=tokenizer.tokenize_response,
deployment=deployment,
discarded_messages=discarded_messages,
eliminate_empty_choices=eliminate_empty_choices,
),
)
else:
response = await predict_non_stream(api_url, headers, request)
if isinstance(response, Response):
return response
response = response_transformer(response)
if response is None:
raise DialException(
status_code=500,
message="The origin returned invalid response",
type="invalid_response_error",
)
if discarded_messages:
response |= {
"statistics": {"discarded_messages": discarded_messages}
}
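        # Cross-check the local token estimates against the usage reported
        # by the upstream; discrepancies are only logged.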
if usage := response.get("usage"):
actual_prompt_tokens = usage["prompt_tokens"]
if actual_prompt_tokens != estimated_prompt_tokens:
logger.warning(
f"Estimated prompt tokens ({estimated_prompt_tokens}) don't match the actual ones ({actual_prompt_tokens})"
)
actual_completion_tokens = usage["completion_tokens"]
estimated_completion_tokens = tokenizer.tokenize_response(
ChatCompletionBlock(resp=response)
)
if actual_completion_tokens != estimated_completion_tokens:
logger.warning(
f"Estimated completion tokens ({estimated_completion_tokens}) don't match the actual ones ({actual_completion_tokens})"
)
return response