aidial_interceptors_sdk/chat_completion/adapter.py (99 lines of code) (raw):

import json
import logging
from typing import AsyncIterator, Type, cast

from aidial_sdk.chat_completion import ChatCompletion as DialChatCompletion
from aidial_sdk.chat_completion import Request as DialRequest
from aidial_sdk.chat_completion import Response as DialResponse
from aidial_sdk.chat_completion.chunks import DefaultChunk
from aidial_sdk.exceptions import HTTPException as DialException
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_interceptors_sdk.chat_completion.annotated_value import (
    AnnotatedException,
    Annotation,
)
from aidial_interceptors_sdk.chat_completion.base import (
    ChatCompletionInterceptor,
    RequestDict,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.error import EarlyStreamExit
from aidial_interceptors_sdk.utils._debug import debug_logging
from aidial_interceptors_sdk.utils._exceptions import dial_exception_decorator
from aidial_interceptors_sdk.utils._http_client import HTTPClientFactory
from aidial_interceptors_sdk.utils._reflection import call_with_extra_body
from aidial_interceptors_sdk.utils.streaming import (
    block_response_to_streaming_chunk,
    map_stream,
    materialize_streaming_errors,
    singleton_stream,
)

_log = logging.getLogger(__name__)


def interceptor_to_chat_completion(
    cls: Type[ChatCompletionInterceptor],
    dial_url: str,
    client_factory: HTTPClientFactory,
) -> DialChatCompletion:
    """Wrap an interceptor class into a DIAL ``ChatCompletion`` application.

    Each incoming request gets its own ``DialClient`` and its own instance
    of ``cls``. The interceptor rewrites the request body, calls the
    upstream(s) through ``call_single_upstream`` and handles every value
    that comes back.

    Args:
        cls: Interceptor class, instantiated once per request.
        dial_url: DIAL base URL passed to ``DialClient.create``.
        client_factory: HTTP client factory passed to ``DialClient.create``.

    Returns:
        A ``DialChatCompletion`` instance that delegates to ``cls``.
    """

    class Impl(DialChatCompletion):
        @dial_exception_decorator
        async def chat_completion(
            self, request: DialRequest, response: DialResponse
        ) -> None:
            # The client is built from the incoming request's API key,
            # API version, JWT and headers.
            client = await DialClient.create(
                dial_url=dial_url,
                api_key=request.api_key,
                api_version=request.api_version,
                authorization=request.jwt,
                headers=request.headers,
                client_factory=client_factory,
            )

            # URL path parameters of the original request become extra
            # keyword arguments of the interceptor constructor.
            handler = cls(
                dial_client=client,
                response=response,
                **request.original_request.path_params,
            )

            raw_body = await request.original_request.json()
            logged_traverse = debug_logging("request")(
                handler.traverse_request
            )
            body = await logged_traverse(raw_body)

            def invoke_upstream(context: Annotation, request: dict):
                # ``request`` is the upstream request supplied by
                # ``call_upstreams``; it shadows the incoming DialRequest.
                return call_single_upstream(client, context, request)

            try:
                await handler.on_stream_start()

                upstream_items = await handler.call_upstreams(
                    body, invoke_upstream
                )
                async for item in upstream_items:
                    # Errors and regular chunks go to different hooks.
                    on_item = (
                        handler.on_stream_error
                        if isinstance(item, AnnotatedException)
                        else handler.traverse_response_chunk
                    )
                    await on_item(item)

                await handler.on_stream_end()
            except EarlyStreamExit:
                # NOTE(review): presumably raised to cut the stream short;
                # it is swallowed here, and on_stream_end is skipped.
                pass

    return Impl()


async def call_single_upstream(
    dial_client: DialClient, context: Annotation, request: RequestDict
) -> AsyncIterator[dict | DialException]:
    """Call the chat completions endpoint once and expose the result as a stream.

    Both streaming and non-streaming upstream responses are turned into a
    stream of chunk dicts. That stream is wrapped by
    ``materialize_streaming_errors``. Per the return annotation, items may
    also be ``DialException`` values; presumably these are errors raised
    while streaming (confirm in ``utils.streaming``).

    Exceptions raised by the initial ``create`` call are not caught here.

    Args:
        dial_client: Client whose ``client.chat.completions.create`` is used.
        context: Annotation included in debug log messages.
        request: Request body passed through ``call_with_extra_body``.
    """
    upstream_result = await call_with_extra_body(
        dial_client.client.chat.completions.create, request
    )
    response = cast(
        AsyncStream[ChatCompletionChunk] | ChatCompletion, upstream_result
    )

    if not isinstance(response, ChatCompletion):
        # Streaming mode. Default fields can stay: every chunk is merged
        # naively with a default chunk generated by the DIAL SDK, and the
        # SDK's values override the upstream ones.
        def to_chunk_dict(chunk: ChatCompletionChunk) -> dict:
            chunk_dict = chunk.to_dict()
            if _log.isEnabledFor(logging.DEBUG):
                _log.debug(
                    f"upstream chunk[{context}]: {json.dumps(chunk_dict)}"
                )
            return chunk_dict

        return materialize_streaming_errors(map_stream(to_chunk_dict, response))

    # Non-streaming mode.
    payload = response.to_dict()
    if _log.isEnabledFor(logging.DEBUG):
        _log.debug(f"upstream response[{context}]: {json.dumps(payload)}")

    # The DIAL SDK regenerates the default fields (those declared on
    # DefaultChunk) on every interceptor call, so they are dropped here and
    # are not proxied from the upstream. If they were kept, they would be
    # merged recursively with the SDK's values. That yields garbage such as
    # "object": "chat.completionchat.completionchat.completion".
    default_fields = DefaultChunk.__annotations__
    payload = {
        field: value
        for field, value in payload.items()
        if field not in default_fields
    }

    return materialize_streaming_errors(
        singleton_stream(block_response_to_streaming_chunk(payload))
    )