aidial_interceptors_sdk/chat_completion/adapter.py (99 lines of code) (raw):
import json
import logging
from typing import AsyncIterator, Type, cast
from aidial_sdk.chat_completion import ChatCompletion as DialChatCompletion
from aidial_sdk.chat_completion import Request as DialRequest
from aidial_sdk.chat_completion import Response as DialResponse
from aidial_sdk.chat_completion.chunks import DefaultChunk
from aidial_sdk.exceptions import HTTPException as DialException
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
RequestDict,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.error import EarlyStreamExit
from aidial_interceptors_sdk.utils._debug import debug_logging
from aidial_interceptors_sdk.utils._exceptions import dial_exception_decorator
from aidial_interceptors_sdk.utils._http_client import HTTPClientFactory
from aidial_interceptors_sdk.utils._reflection import call_with_extra_body
from aidial_interceptors_sdk.utils.streaming import (
block_response_to_streaming_chunk,
map_stream,
materialize_streaming_errors,
singleton_stream,
)
# Module-level logger; used below for debug-level dumps of upstream
# responses and streaming chunks.
_log = logging.getLogger(__name__)
def interceptor_to_chat_completion(
    cls: Type[ChatCompletionInterceptor],
    dial_url: str,
    client_factory: HTTPClientFactory,
) -> DialChatCompletion:
    """Build a DIAL chat completion handler driven by an interceptor class.

    Args:
        cls: Interceptor class; a new instance is created per request.
        dial_url: Passed to ``DialClient.create`` as ``dial_url``.
        client_factory: Passed to ``DialClient.create`` as
            ``client_factory``.

    Returns:
        An instance of a ``DialChatCompletion`` subclass whose
        ``chat_completion`` method runs the interceptor for each request.
    """

    class Impl(DialChatCompletion):
        @dial_exception_decorator
        async def chat_completion(
            self, request: DialRequest, response: DialResponse
        ) -> None:
            # The client is created with the credentials and headers
            # carried by the incoming request.
            client = await DialClient.create(
                dial_url=dial_url,
                api_key=request.api_key,
                api_version=request.api_version,
                authorization=request.jwt,
                headers=request.headers,
                client_factory=client_factory,
            )

            raw_request = request.original_request

            # Path parameters of the original request are passed to the
            # interceptor constructor as keyword arguments.
            interceptor = cls(
                dial_client=client,
                response=response,
                **raw_request.path_params,
            )

            # The interceptor transforms the request body first; the
            # call is wrapped with debug_logging("request").
            original_body = await raw_request.json()
            logged_traverse = debug_logging("request")(
                interceptor.traverse_request
            )
            request_body = await logged_traverse(original_body)

            def call_upstream(context: Annotation, request: dict):
                """Call one upstream using this request's DIAL client."""
                return call_single_upstream(client, context, request)

            try:
                await interceptor.on_stream_start()

                upstream_values = await interceptor.call_upstreams(
                    request_body, call_upstream
                )
                async for item in upstream_values:
                    if not isinstance(item, AnnotatedException):
                        await interceptor.traverse_response_chunk(item)
                        continue
                    await interceptor.on_stream_error(item)

                await interceptor.on_stream_end()
            except EarlyStreamExit:
                # EarlyStreamExit ends processing without an error; any
                # remaining steps of the streaming phase are skipped.
                pass

    return Impl()
async def call_single_upstream(
    dial_client: DialClient, context: Annotation, request: RequestDict
) -> AsyncIterator[dict | DialException]:
    """Call the upstream chat completion once and return a chunk stream.

    Args:
        dial_client: Client whose ``client.chat.completions.create`` is
            invoked (via ``call_with_extra_body``).
        context: Annotation of this upstream call; only used here to
            label debug log messages.
        request: Request payload handed to ``call_with_extra_body``.

    Returns:
        The stream produced by ``materialize_streaming_errors``. For a
        block (non-streaming) upstream response this wraps a single
        chunk; for a streaming response it wraps every upstream chunk
        converted to a dict.
        NOTE(review): per the return annotation, errors are presumably
        yielded as values rather than raised - confirm against
        ``materialize_streaming_errors``.
    """
    upstream_result = cast(
        AsyncStream[ChatCompletionChunk] | ChatCompletion,
        await call_with_extra_body(
            dial_client.client.chat.completions.create, request
        ),
    )

    if not isinstance(upstream_result, ChatCompletion):
        # Streaming mode:
        # the default fields are left in place, because they get
        # overridden by the ones generated by DIAL SDK when each chunk
        # is merged naively with a default chunk.
        def on_upstream_chunk(chunk: ChatCompletionChunk) -> dict:
            d = chunk.to_dict()
            if _log.isEnabledFor(logging.DEBUG):
                _log.debug(f"upstream chunk[{context}]: {json.dumps(d)}")
            return d

        chunk_stream = map_stream(on_upstream_chunk, upstream_result)
        return materialize_streaming_errors(chunk_stream)

    # Non-streaming mode.
    resp = upstream_result.to_dict()
    if _log.isEnabledFor(logging.DEBUG):
        _log.debug(f"upstream response[{context}]: {json.dumps(resp)}")

    # Drop the default fields which DIAL SDK generates automatically.
    # As a consequence these fields aren't proxied from the upstream;
    # they are recreated on each interceptor call.
    # Left in place, they would be merged recursively with the
    # SDK-generated ones, ending up with e.g.
    # "object": "chat.completionchat.completionchat.completion"
    for default_field in DefaultChunk.__annotations__:
        resp.pop(default_field, None)

    single_chunk = block_response_to_streaming_chunk(resp)
    return materialize_streaming_errors(singleton_stream(single_chunk))