aidial_assistant/commands/plugin_callback.py (68 lines of code) (raw):

from types import TracebackType from typing import Callable from typing_extensions import override from aidial_assistant.chain.callbacks.args_callback import ArgsCallback from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback from aidial_assistant.commands.base import ( ExecutionCallback, ResultObject, ResultType, ) class PluginCommandCallback(CommandCallback): def __init__(self, callback: ExecutionCallback): self.callback = callback @override def on_command(self, command: str): self.callback(f"```javascript\n{command}") @override def args_callback(self) -> ArgsCallback: return ArgsCallback(self.callback) @override def execution_callback(self) -> ExecutionCallback: return self.callback @override def on_result(self, result: ResultObject): syntax = "json" if result.type == ResultType.JSON else "text" self.callback(f"\n```\n```{syntax}\n{result.text}\n```\n") @override def on_error(self, error: BaseException): self.callback(f"\n```\n```\nError: {str(error)}\n```\n") @override def __exit__( self, __exc_type: type[BaseException] | None, __exc_value: BaseException | None, __traceback: TracebackType | None, ): if __exc_value is not None: self.on_error(__exc_value) class PluginResultCallback(ResultCallback): def __init__(self, callback: Callable[[str], None]): self.callback = callback @override def on_result(self, chunk: str): self.callback(chunk) class PluginChainCallback(ChainCallback): def __init__(self, callback: Callable[[str], None]): self.callback = callback self._result = "" @override def command_callback(self) -> PluginCommandCallback: return PluginCommandCallback(self.callback) @override def result_callback(self) -> ResultCallback: return PluginResultCallback(self._on_result) @override def on_state(self, request: str, response: str): # Plugin state is not currently supported pass @override def on_error(self, title: str, error: str): pass @property def result(self) -> str: return self._result def _on_result(self, chunk): self._result += chunk self.callback(chunk)