# aidial_adapter_openai/endpoints/chat_completion.py

from typing import assert_never from fastapi import Request from aidial_adapter_openai.app_config import ApplicationConfig from aidial_adapter_openai.completions import chat_completion as completion from aidial_adapter_openai.constant import ChatCompletionDeploymentType from aidial_adapter_openai.dalle3 import ( chat_completion as dalle3_chat_completion, ) from aidial_adapter_openai.databricks import ( chat_completion as databricks_chat_completion, ) from aidial_adapter_openai.dial_api.storage import create_file_storage from aidial_adapter_openai.gpt import gpt_chat_completion from aidial_adapter_openai.gpt4_multi_modal.chat_completion import ( gpt4_vision_chat_completion, gpt4o_chat_completion, ) from aidial_adapter_openai.mistral import ( chat_completion as mistral_chat_completion, ) from aidial_adapter_openai.utils.auth import get_credentials from aidial_adapter_openai.utils.image_tokenizer import get_image_tokenizer from aidial_adapter_openai.utils.parsers import completions_parser, parse_body from aidial_adapter_openai.utils.request import ( get_api_version, get_request_app_config, ) from aidial_adapter_openai.utils.streaming import create_server_response from aidial_adapter_openai.utils.tokenizer import ( MultiModalTokenizer, PlainTextTokenizer, ) async def call_chat_completion( deployment_id: str, data: dict, is_stream: bool, request: Request, app_config: ApplicationConfig, ): # Azure OpenAI deployments ignore "model" request field, # since the deployment id is already encoded in the endpoint path. # This is not the case for non-Azure OpenAI deployments, so # they require the "model" field to be set. # However, openai==1.33.0 requires the "model" field for **both** # Azure and non-Azure deployments. # Therefore, we provide the "model" field for all deployments here. # The same goes for /embeddings endpoint. 
data["model"] = deployment_id creds = await get_credentials(request) api_version = get_api_version(request) upstream_endpoint = request.headers["X-UPSTREAM-ENDPOINT"] if completions_endpoint := completions_parser.parse(upstream_endpoint): return await completion( data, completions_endpoint, creds, api_version, deployment_id, app_config, ) deployment_type = app_config.get_chat_completion_deployment_type( deployment_id ) match deployment_type: case ChatCompletionDeploymentType.DALLE3: storage = create_file_storage("images", request.headers) return await dalle3_chat_completion( data, upstream_endpoint, creds, is_stream, storage, app_config.DALLE3_AZURE_API_VERSION, ) case ChatCompletionDeploymentType.MISTRAL: return await mistral_chat_completion(data, upstream_endpoint, creds) case ChatCompletionDeploymentType.DATABRICKS: return await databricks_chat_completion( data, upstream_endpoint, creds ) case ChatCompletionDeploymentType.GPT4_VISION: tokenizer = MultiModalTokenizer( "gpt-4", get_image_tokenizer(deployment_type) ) return await gpt4_vision_chat_completion( data, deployment_id, upstream_endpoint, creds, is_stream, create_file_storage("images", request.headers), api_version, tokenizer, app_config.ELIMINATE_EMPTY_CHOICES, ) case ( ChatCompletionDeploymentType.GPT4O | ChatCompletionDeploymentType.GPT4O_MINI ): tokenizer = MultiModalTokenizer( app_config.MODEL_ALIASES.get(deployment_id, deployment_id), get_image_tokenizer(deployment_type), ) return await gpt4o_chat_completion( data, deployment_id, upstream_endpoint, creds, is_stream, create_file_storage("images", request.headers), api_version, tokenizer, app_config.ELIMINATE_EMPTY_CHOICES, ) case ChatCompletionDeploymentType.GPT_TEXT_ONLY: tokenizer = PlainTextTokenizer( model=app_config.MODEL_ALIASES.get(deployment_id, deployment_id) ) return await gpt_chat_completion( data, deployment_id, upstream_endpoint, creds, api_version, tokenizer, app_config.ELIMINATE_EMPTY_CHOICES, ) case _: assert_never(deployment_type) 
async def chat_completion(deployment_id: str, request: Request):
    """Endpoint entry point for chat completions.

    Parses the request body and delegates to ``call_chat_completion``.
    For deployments listed in ``NON_STREAMING_DEPLOYMENTS`` a streamed
    client response is emulated: the upstream call is forced to be
    non-streaming and the result is re-wrapped for the client.
    """
    config = get_request_app_config(request)
    payload = await parse_body(request)

    wants_stream = bool(payload.get("stream"))
    must_emulate = wants_stream and (
        deployment_id in config.NON_STREAMING_DEPLOYMENTS
    )

    if must_emulate:
        # Ask the upstream for a plain response; streaming is emulated
        # on the way back to the client.
        payload["stream"] = False

    upstream_response = await call_chat_completion(
        deployment_id, payload, wants_stream, request, config
    )
    return create_server_response(must_emulate, upstream_response)