aidial_adapter_dial/app.py (149 lines of code) (raw):
import json
import logging
from urllib.parse import urlparse
from aidial_sdk.exceptions import InvalidRequestError
from aidial_sdk.telemetry.init import init_telemetry
from aidial_sdk.telemetry.types import TelemetryConfig
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from openai import AsyncAzureOpenAI, AsyncStream, BaseModel
from openai.types import CreateEmbeddingResponse
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from aidial_adapter_dial.transformer import AttachmentTransformer
from aidial_adapter_dial.utils.dict import censor_ci_dict
from aidial_adapter_dial.utils.env import get_env
from aidial_adapter_dial.utils.exceptions import to_dial_exception
from aidial_adapter_dial.utils.http_client import get_http_client
from aidial_adapter_dial.utils.log_config import configure_loggers
from aidial_adapter_dial.utils.reflection import call_with_extra_body
from aidial_adapter_dial.utils.sse_stream import to_openai_sse_stream
from aidial_adapter_dial.utils.storage import FileStorage
from aidial_adapter_dial.utils.streaming import amap_stream, map_stream
# ASGI application exposing the proxy routes defined in this module.
app = FastAPI()
# Telemetry and logger setup run at import time, before any request is served.
init_telemetry(app, TelemetryConfig())
configure_loggers()
log = logging.getLogger(__name__)
# Evaluated once at import time (after configure_loggers()), so it reflects the
# startup log level; the handlers use it to skip building debug-only log text.
is_debug = log.isEnabledFor(logging.DEBUG)
# Optional header with the API key for the upstream (remote) DIAL; when absent,
# the caller's local 'api-key' is reused, but only for a local-host upstream.
UPSTREAM_KEY_HEADER = "X-UPSTREAM-KEY"
# Required header with the full upstream endpoint URL
# (must end with the proxied endpoint path, e.g. "/chat/completions").
UPSTREAM_ENDPOINT_HEADER = "X-UPSTREAM-ENDPOINT"
# Base URL of the local DIAL instance (the "DIAL_URL" setting read via get_env).
LOCAL_DIAL_URL = get_env("DIAL_URL")
def get_hostname(url: str) -> str:
    """Return the ``scheme://netloc`` origin of *url*.

    Path, query string and fragment are discarded.

    >>> get_hostname("https://dial.example.com:8080/openai/deployments/x")
    'https://dial.example.com:8080'
    """
    return "{0.scheme}://{0.netloc}".format(urlparse(url))
class AzureClient(BaseModel):
    """Per-request bundle of an upstream OpenAI client and an attachment transformer.

    Instances are created by :meth:`parse` from the incoming request's
    headers and query parameters.
    """

    # OpenAI SDK client whose base URL is the upstream deployment endpoint.
    client: AsyncAzureOpenAI
    # Built from the local and the remote DIAL file storages (see ``parse``);
    # the route handlers pass requests/responses through it.
    attachment_transformer: AttachmentTransformer

    class Config:
        # The field types above are not pydantic models (e.g. AsyncAzureOpenAI),
        # so pydantic has to accept them as arbitrary types.
        arbitrary_types_allowed = True

    @classmethod
    async def parse(cls, request: Request, endpoint_name: str) -> "AzureClient":
        """Build an upstream client from the incoming proxy request.

        Args:
            request: the incoming request. It must carry the ``api-key`` and
                ``X-UPSTREAM-ENDPOINT`` headers. ``X-UPSTREAM-KEY`` may be
                omitted only when the upstream endpoint is on the local DIAL
                host, in which case the local ``api-key`` is reused.
            endpoint_name: path the upstream endpoint must end with
                (e.g. ``"chat/completions"``). It is stripped off to obtain
                the client's base URL.

        Returns:
            An ``AzureClient`` targeting the upstream deployment.

        Raises:
            InvalidRequestError: a required header is missing; or the upstream
                key is missing for a non-local upstream host; or the upstream
                endpoint does not end with ``/{endpoint_name}``.
        """
        body = await request.json()
        headers = request.headers.mutablecopy()
        query_params = request.query_params

        if is_debug:
            log.debug(f"request.body: {body}")
            secret_headers = ["api-key", "authorization", UPSTREAM_KEY_HEADER]
            log.debug(
                f"request.headers: {censor_ci_dict(headers, secret_headers)}"
            )
            log.debug(f"request.params: {query_params}")

        # Always required: it authenticates against the local DIAL file storage.
        local_dial_api_key = headers.get("api-key")
        if not local_dial_api_key:
            raise InvalidRequestError("The 'api-key' request header is missing")

        upstream_endpoint = headers.get(UPSTREAM_ENDPOINT_HEADER)
        if not upstream_endpoint:
            raise InvalidRequestError(
                f"The {UPSTREAM_ENDPOINT_HEADER!r} request header is missing"
            )

        remote_dial_url = get_hostname(upstream_endpoint)

        remote_dial_api_key = headers.get(UPSTREAM_KEY_HEADER)
        if not remote_dial_api_key:
            # Without an explicit upstream key the local key is forwarded,
            # which is only allowed when the upstream lives on the local DIAL
            # host. Both sides are reduced to scheme://netloc so that a
            # trailing slash or path in DIAL_URL doesn't cause a mismatch.
            if remote_dial_url != get_hostname(LOCAL_DIAL_URL):
                raise InvalidRequestError(
                    f"Given that {UPSTREAM_KEY_HEADER!r} header is missing, "
                    f"it's expected that hostname of upstream endpoint ({upstream_endpoint!r}) is "
                    f"the same as the local DIAL URL ({LOCAL_DIAL_URL!r}) "
                )
            remote_dial_api_key = local_dial_api_key

        # The OpenAI client appends the endpoint path to base_url itself, so
        # the base URL is the upstream endpoint minus "/{endpoint_name}".
        endpoint_suffix = f"/{endpoint_name}"
        if not upstream_endpoint.endswith(endpoint_suffix):
            raise InvalidRequestError(
                f"The {UPSTREAM_ENDPOINT_HEADER!r} request header must end with {endpoint_suffix!r}"
            )
        upstream_endpoint = upstream_endpoint.removesuffix(endpoint_suffix)

        client = AsyncAzureOpenAI(
            base_url=upstream_endpoint,
            api_key=remote_dial_api_key,
            # NOTE: defaulting missing api-version to an empty string, because
            # 1. openai library doesn't allow for a missing api-version
            # and a workaround for it would be a recreation of AsyncAzureOpenAI with a check disabled:
            # https://gitlab.deltixhub.com/Deltix/openai-apps/dial-interceptor-example/-/blob/62760a4c7a7be740b1c2bc60f14a0a568f31a0bc/aidial_interceptor_example/utils/azure.py#L1-5
            # 2. OpenAI adapter treats a missing api-version in the same way as an empty string and that's the only
            # place where api-version has any meaning, so the query param modification is safe.
            # https://github.com/epam/ai-dial-adapter-openai/blob/b462d1c26ce8f9d569b9c085a849206aad91becf/aidial_adapter_openai/app.py#L93
            api_version=query_params.get("api-version") or "",
            http_client=get_http_client(),
        )

        attachment_transformer = await AttachmentTransformer.create(
            local_storage=FileStorage(
                dial_url=LOCAL_DIAL_URL,
                api_key=local_dial_api_key,
            ),
            remote_storage=FileStorage(
                dial_url=remote_dial_url,
                api_key=remote_dial_api_key,
            ),
        )

        return cls(
            client=client,
            attachment_transformer=attachment_transformer,
        )
@app.post("/embeddings")
@app.post("/openai/deployments/{deployment_id:path}/embeddings")
async def embeddings_proxy(request: Request):
    """Forward an embeddings request to the upstream DIAL deployment.

    The JSON body is handed to the SDK through ``call_with_extra_body``
    (presumably so fields the SDK method doesn't declare still reach the
    upstream). The upstream response is returned as a plain dict.
    """
    request_body = await request.json()
    azure_client = await AzureClient.parse(request, "embeddings")

    create_embeddings = azure_client.client.embeddings.create
    upstream_response: CreateEmbeddingResponse = await call_with_extra_body(
        create_embeddings, request_body
    )
    return upstream_response.to_dict()
@app.post("/chat/completions")
@app.post("/openai/deployments/{deployment_id:path}/chat/completions")
async def chat_completions_proxy(request: Request):
    """Forward a chat completion request to the upstream DIAL deployment.

    The request body, and the upstream response (either a single completion
    or every streamed chunk), are passed through the client's attachment
    transformer. A streamed response is re-emitted as ``text/event-stream``.
    """
    azure_client = await AzureClient.parse(request, "chat/completions")
    attachments = azure_client.attachment_transformer

    payload = await attachments.modify_request(await request.json())
    if is_debug:
        log.debug(f"request.body transformed: {payload}")

    upstream: AsyncStream[ChatCompletionChunk] | ChatCompletion = (
        await call_with_extra_body(
            azure_client.client.chat.completions.create, payload
        )
    )

    # Non-streaming: transform the whole completion and return it as a dict.
    if not isinstance(upstream, AsyncStream):
        result = await attachments.modify_response(upstream.to_dict())
        if is_debug:
            log.debug(f"response: {json.dumps(result)}")
        return result

    # Streaming: convert each SDK chunk to a dict, transform it, re-emit as SSE.
    async def transform_chunk(chunk: dict) -> dict:
        transformed = await attachments.modify_response_chunk(chunk)
        if is_debug:
            log.debug(f"chunk: {json.dumps(transformed)}")
        return transformed

    dict_chunks = map_stream(lambda chunk_obj: chunk_obj.to_dict(), upstream)
    return StreamingResponse(
        to_openai_sse_stream(amap_stream(transform_chunk, dict_chunks)),
        media_type="text/event-stream",
    )
@app.exception_handler(Exception)
def exception_handler(request: Request, e: Exception):
    """Log any unhandled exception and answer with its DIAL-formatted error response."""
    exc_type = type(e)
    log.exception(f"caught exception: {exc_type.__module__}.{exc_type.__name__}")
    return to_dial_exception(e).to_fastapi_response()
@app.get("/health")
def health():
    """Liveness probe; always reports an ``ok`` status."""
    return dict(status="ok")