aidial_interceptors_sdk/chat_completion/response_handler.py (96 lines of code) (raw):
from typing import Dict, List
from aidial_sdk.chat_completion import Response
from aidial_sdk.chat_completion.chunks import BaseChunk
from aidial_sdk.pydantic_v1 import PrivateAttr
from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedChunk,
)
from aidial_interceptors_sdk.chat_completion.element_path import (
ChoiceContext,
ElementPath,
)
from aidial_interceptors_sdk.chat_completion.helpers import (
traverse_dict_value,
traverse_list,
)
from aidial_interceptors_sdk.chat_completion.index_mapper import IndexMapper
from aidial_interceptors_sdk.chat_completion.response_message_handler import (
ResponseMessageHandler,
)
from aidial_interceptors_sdk.utils._dial_sdk import send_chunk_to_response
from aidial_interceptors_sdk.utils.not_given import NotGiven
class ResponseHandler(ResponseMessageHandler):
"""
Callbacks for handling chat completion responses.
1. the callbacks may mutate the response in place,
2. the callbacks are applied in bottom-up order (the order of callbacks in the class definition reflects the order of application),
3. a callback for a list element return a list of objects, which allows for adding or removing elements. Such callbacks have the type signature: `T -> List[T]`,
4. a callback for a dictionary value returns an optional dictionary, which allows for adding or removing keys. Such callbacks have the type signature: `(T | NotGiven | None) -> (T | NotGiven | None)`.
5. `on_response*` methods are expected to be overridden by subclasses.
6. `traverse_response*` methods are not expected to be overridden by subclasses.
The callbacks are intended to be used for:
1. collecting information about the response,
2. inplace modifying the response.
Avoid it creating new entities in the response, creation of deeply nested
entities is hard if possible.
E.g. adding a new stage in a response which doesn't have custom content is impossible,
because one would have to create "custom_content" field _before_
creating the "stages" field. Bottom-up application order of callbacks
doesn't allow for that.
"""
response: Response
# NOTE: `_stage_indices = {}` isn't going to work, since
# the underscored field `_stage_indices` will be shared across
# all instances of the class.
_stage_indices: Dict[int, IndexMapper[int]] = PrivateAttr({})
def _get_stage_index_mapper(self, choice_idx: int) -> IndexMapper[int]:
if choice_idx not in self._stage_indices:
self._stage_indices[choice_idx] = IndexMapper()
return self._stage_indices[choice_idx]
def reserve_stage_index(self, choice_idx: int) -> int:
return self._get_stage_index_mapper(choice_idx).reserve()
def send_chunk(self, chunk: BaseChunk | dict):
return send_chunk_to_response(self.response, chunk)
async def on_response_message(
self, path: ElementPath, message: dict | NotGiven | None
) -> dict | NotGiven | None:
return message
async def on_response_finish_reason(
self, path: ElementPath, finish_reason: str | NotGiven | None
) -> str | NotGiven | None:
return finish_reason
async def on_response_choice(
self, path: ElementPath, choice: dict
) -> List[dict] | dict:
return choice
async def on_response_choices(
self, choices: List[dict] | NotGiven | None
) -> List[dict] | NotGiven | None:
return choices
# TODO: add path to the signature and to the rest of similar methods
async def on_response_usage(
self, usage: dict | NotGiven | None
) -> dict | NotGiven | None:
return usage
async def on_stream_chunk(self, chunk: dict) -> None:
self.send_chunk(chunk)
async def traverse_response_chunk(self, ann_chunk: AnnotatedChunk) -> None:
r = ann_chunk.chunk
async def traverse_message(
path: ElementPath, message: dict | NotGiven | None
) -> dict | NotGiven | None:
if message is not None and not isinstance(message, NotGiven):
message = await self.traverse_response_message(path, message)
return await self.on_response_message(path, message)
async def traverse_choice(
path: ElementPath, choice: dict
) -> List[dict] | dict:
choice = await traverse_dict_value(
path, choice, "finish_reason", self.on_response_finish_reason
)
choice = await traverse_dict_value(
path, choice, "delta", traverse_message
)
return await self.on_response_choice(path, choice)
async def traverse_choices(
path: ElementPath, choices: List[dict] | NotGiven | None
) -> List[dict] | NotGiven | None:
def with_choice_ctx(choice_idx: int) -> ElementPath:
return path.with_choice_ctx(
ChoiceContext(
index=choice_idx,
stage_index_mapper=self._get_stage_index_mapper(
choice_idx
),
)
)
choices = await traverse_list(
with_choice_ctx, choices, traverse_choice
)
return await self.on_response_choices(choices)
async def traverse_response_usage(
path: ElementPath, usage: dict | NotGiven | None
) -> dict | NotGiven | None:
return await self.on_response_usage(usage)
path = ElementPath(response_ctx=ann_chunk.annotation)
r = await traverse_dict_value(path, r, "usage", traverse_response_usage)
r = await traverse_dict_value(path, r, "choices", traverse_choices)
await self.on_stream_chunk(r)