# aidial_interceptors_sdk/chat_completion/base.py
from typing import AsyncIterator, Awaitable, Callable
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
AnnotatedValue,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.request_handler import (
RequestHandler,
)
from aidial_interceptors_sdk.chat_completion.response_handler import (
ResponseHandler,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.utils.streaming import annotate_stream
# Loose alias for a chat-completion request body — presumably the parsed JSON
# payload of the incoming request (no stricter schema is enforced here).
RequestDict = dict
class ChatCompletionInterceptor(RequestHandler, ResponseHandler):
    """Base interceptor for chat completions.

    Combines the request-side hooks from ``RequestHandler`` with the
    response-side hooks from ``ResponseHandler``. The default implementations
    below forward everything unchanged, so subclasses only override what
    they need.
    """

    # Client used to talk to the DIAL backend; injected by the caller.
    dial_client: DialClient

    async def call_upstreams(
        self,
        request: RequestDict,
        call_upstream: Callable[
            [Annotation, RequestDict],
            Awaitable[AsyncIterator[dict | DialException]],
        ],
    ) -> AsyncIterator[AnnotatedValue]:
        """Invoke the (single) upstream and annotate its resulting stream.

        Only one upstream is called here, so no annotation is required to
        tell streams apart — ``None`` is used throughout.
        """
        annotation: Annotation = None
        upstream_stream = await call_upstream(annotation, request)
        return annotate_stream(annotation, upstream_stream)

    async def on_stream_start(self) -> None:
        """Hook invoked before the first upstream chunk is processed.

        Default: no-op.
        """
        # TODO: consider collecting the chunks this method produces into a
        # separate list and merging them with the first incoming chunk;
        # otherwise a choice may be opened *before* its "assistant" role
        # is reported.
        pass

    async def on_stream_error(self, error: AnnotatedException) -> None:
        """Hook invoked when the upstream stream yields an error.

        Default: unwrap the annotated error and re-raise it.
        """
        raise error.error

    async def on_stream_end(self) -> None:
        """Hook invoked after the last upstream chunk is processed.

        Default: no-op.
        """
        # TODO: consider withholding the last chunk produced by
        # on_stream_chunk and merging it with the chunks reported here.
        pass
class ChatCompletionNoOpInterceptor(ChatCompletionInterceptor):
    # Identity interceptor: adds nothing on top of the pass-through defaults
    # inherited from ChatCompletionInterceptor.
    pass