aidial_assistant/application/assistant_callback.py (85 lines of code) (raw):

from types import TracebackType

from aidial_sdk.chat_completion import Status
from aidial_sdk.chat_completion.choice import Choice
from aidial_sdk.chat_completion.stage import Stage
from typing_extensions import override

from aidial_assistant.chain.callbacks.args_callback import ArgsCallback
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
from aidial_assistant.commands.base import ExecutionCallback, ResultObject
from aidial_assistant.utils.state import Invocation


class AssistantCommandCallback(CommandCallback):
    """Write the progress of a single command into one DIAL chat stage.

    The command's display name is appended to the stage name; execution
    output and errors are appended to the stage content.
    """

    def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):
        """
        Args:
            stage: Stage that this command's output is written to.
            addon_name_mapping: Maps command names to the names shown in the
                stage title; commands missing from the map are shown as-is.
        """
        self.stage = stage
        self.addon_name_mapping = addon_name_mapping
        # Built once and handed out by args_callback(); it is wired to append
        # to the stage name.
        self._args_callback = ArgsCallback(self._on_stage_name)

    @override
    def on_command(self, command: str):
        """Append the display name of `command` to the stage name."""
        self._on_stage_name(self.addon_name_mapping.get(command, command))

    @override
    def execution_callback(self) -> ExecutionCallback:
        """Return a callback that appends execution output to the stage content."""
        return self._on_stage_content

    @override
    def args_callback(self) -> ArgsCallback:
        """Return the ArgsCallback created in __init__ (appends to the stage name)."""
        return self._args_callback

    @override
    def on_result(self, result: ResultObject):
        # Intentionally a no-op: the result is reported by the plugin itself.
        pass

    @override
    def on_error(self, error: BaseException):
        """Append the error's message to the stage content on a new line."""
        self.stage.append_content(f"\n{error}")

    def _on_stage_name(self, chunk: str):
        # Append a chunk of text to the stage title.
        self.stage.append_name(chunk)

    def _on_stage_content(self, chunk: str):
        # Append a chunk of text to the stage body.
        self.stage.append_content(chunk)

    def __enter__(self):
        """Enter the underlying stage and return this callback."""
        self.stage.__enter__()
        return self

    @override
    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ):
        """Report a pending exception into the stage, then exit the stage.

        The exception is not suppressed: this method returns None.
        """
        if __exc_value is not None:
            self.on_error(__exc_value)
        self.stage.__exit__(__exc_type, __exc_value, __traceback)


class AssistantResultCallback(ResultCallback):
    """Stream the final answer into the chat choice content."""

    def __init__(self, choice: Choice):
        self.choice = choice

    def on_result(self, chunk: str):
        """Append a chunk of the final result to the choice content."""
        self.choice.append_content(chunk)


class AssistantChainCallback(ChainCallback):
    """Route chain events into a DIAL chat choice.

    Each command is given its own stage, the final answer is appended to the
    choice content, errors become failed stages, and every request/response
    pair reported through on_state is recorded as an Invocation.
    """

    def __init__(self, choice: Choice, addon_name_mapping: dict[str, str]):
        """
        Args:
            choice: Choice that stages and result content are written to.
            addon_name_mapping: Forwarded to each AssistantCommandCallback to
                map command names to displayed stage names.
        """
        self.choice = choice
        self.addon_name_mapping = addon_name_mapping
        self._invocations: list[Invocation] = []
        # Index of the last recorded invocation; -1 means none yet, so the
        # first Invocation gets index 0.
        self._invocation_index: int = -1
        # NOTE(review): not read anywhere in this class — confirm whether it
        # is still used elsewhere or can be removed.
        self._discarded_messages: int = 0

    @override
    def command_callback(self) -> CommandCallback:
        """Create a command callback that writes into a new stage of the choice."""
        return AssistantCommandCallback(
            self.choice.create_stage(), self.addon_name_mapping
        )

    @override
    def on_state(self, request: str, response: str):
        """Record a request/response pair as the next sequentially indexed Invocation."""
        self._invocation_index += 1
        self._invocations.append(
            Invocation(
                index=self._invocation_index,
                request=request,
                response=response,
            )
        )

    @override
    def result_callback(self) -> ResultCallback:
        """Return a callback that appends the final answer to the choice content."""
        return AssistantResultCallback(self.choice)

    @override
    def on_error(self, title: str, error: str):
        """Report `error` in a new stage titled `title` and close it as FAILED.

        The stage is closed in a ``finally`` block so it is never left open,
        even when writing the message raises.
        """
        stage = self.choice.create_stage(title)
        stage.open()
        try:
            stage.append_content(f"Error: {error}\n")
        finally:
            stage.close(Status.FAILED)

    @property
    def invocations(self) -> list[Invocation]:
        """Invocations recorded so far, in order (the internal list, not a copy)."""
        return self._invocations