aidial_assistant/tools_chain/tools_chain.py

import json
from typing import Any, Tuple, cast

from openai import BadRequestError
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionMessageToolCallParam,
    ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function

from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.command_chain import (
    CommandConstructor,
    LimitExceededException,
    ModelRequestLimiter,
)
from aidial_assistant.chain.command_result import (
    CommandInvocation,
    CommandResult,
    Commands,
    Responses,
    Status,
    commands_to_text,
    responses_to_text,
)
from aidial_assistant.chain.history import MessageScope, ScopedMessage
from aidial_assistant.chain.model_response_reader import (
    AssistantProtocolException,
)
from aidial_assistant.commands.base import Command
from aidial_assistant.model.model_client import (
    ExtraResultsCallback,
    ModelClient,
)
from aidial_assistant.utils.exceptions import RequestParameterValidationError
from aidial_assistant.utils.open_ai import tool_calls_message, tool_message


def convert_commands_to_tools(
    scoped_messages: list[ScopedMessage],
) -> list[ChatCompletionMessageParam]:
    """Convert internal command/response state messages into OpenAI tool-call
    and tool messages, assigning sequential tool-call ids so that every
    response is matched to the call that produced it. Non-internal messages
    are passed through unchanged."""
    messages: list[ChatCompletionMessageParam] = []
    next_tool_id: int = 0
    last_call_count: int = 0
    for scoped_message in scoped_messages:
        message = scoped_message.message
        if scoped_message.scope == MessageScope.INTERNAL:
            content = cast(str, message.get("content"))
            if not content:
                raise RequestParameterValidationError(
                    "State is broken. Content cannot be empty.",
                    param="messages",
                )

            if message["role"] == "assistant":
                # An assistant state message holds the commands it invoked.
                commands: Commands = json.loads(content)
                messages.append(
                    tool_calls_message(
                        [
                            ChatCompletionMessageToolCallParam(
                                id=str(next_tool_id + index),
                                function=Function(
                                    name=command["command"],
                                    arguments=json.dumps(
                                        command["arguments"]
                                    ),
                                ),
                                type="function",
                            )
                            for index, command in enumerate(
                                commands["commands"]
                            )
                        ],
                    )
                )
                last_call_count = len(commands["commands"])
                next_tool_id += last_call_count
            elif message["role"] == "user":
                # A user state message holds the responses to those commands.
                responses: Responses = json.loads(content)
                response_count = len(responses["responses"])
                if response_count != last_call_count:
                    raise RequestParameterValidationError(
                        f"Expected {last_call_count} responses, but got {response_count}.",
                        param="messages",
                    )
                first_tool_id = next_tool_id - last_call_count
                messages.extend(
                    [
                        tool_message(
                            content=response["response"],
                            tool_call_id=str(first_tool_id + index),
                        )
                        for index, response in enumerate(
                            responses["responses"]
                        )
                    ]
                )
        else:
            messages.append(scoped_message.message)

    return messages


def _publish_command(
    command_callback: CommandCallback, name: str, arguments: str
):
    # Report the tool name and its raw arguments to the UI callback.
    command_callback.on_command(name)
    args_callback = command_callback.args_callback()
    args_callback.on_args_start()
    args_callback.on_args_chunk(arguments)
    args_callback.on_args_end()


CommandTool = Tuple[CommandConstructor, ChatCompletionToolParam]
CommandToolDict = dict[str, CommandTool]


class ToolCallsCallback(ExtraResultsCallback):
    """Collects the tool calls emitted by the model during streaming."""

    def __init__(self):
        self.tool_calls: list[ChatCompletionMessageToolCallParam] = []

    def on_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCallParam]
    ):
        self.tool_calls = tool_calls


class ToolsChain:
    """Drives the tool-calling loop: queries the model, executes the tools it
    requests, appends the results to the dialog, and repeats until the model
    answers without requesting any tool calls."""

    def __init__(
        self,
        model: ModelClient,
        commands: CommandToolDict,
        max_completion_tokens: int | None = None,
    ):
        self.model = model
        self.commands = commands
        self.model_extra_args = (
            {}
            if max_completion_tokens is None
            else {"max_tokens": max_completion_tokens}
        )

    async def run_chat(
        self,
        messages: list[ChatCompletionMessageParam],
        callback: ChainCallback,
        model_request_limiter: ModelRequestLimiter | None = None,
    ):
        result_callback = callback.result_callback()
        last_message_block_length = 0
        tools = [tool for _, tool in self.commands.values()]
        all_messages = messages.copy()
        while True:
            tool_calls_callback = ToolCallsCallback()
            try:
                if model_request_limiter:
                    await model_request_limiter.verify_limit(all_messages)

                async for chunk in self.model.agenerate(
                    all_messages,
                    tool_calls_callback,
                    tools=tools,
                    **self.model_extra_args,
                ):
                    result_callback.on_result(chunk)
            except (BadRequestError, LimitExceededException) as e:
                if (
                    last_message_block_length == 0
                    or isinstance(e, BadRequestError)
                    and e.code == "429"
                ):
                    raise

                # If the dialog exceeds the model context size, remove the
                # last message block and retry without tools.
                all_messages = all_messages[:-last_message_block_length]
                async for chunk in self.model.agenerate(
                    all_messages, tool_calls_callback
                ):
                    result_callback.on_result(chunk)

                break

            if not tool_calls_callback.tool_calls:
                # The model produced a final answer.
                break

            previous_message_count = len(all_messages)
            all_messages.append(
                tool_calls_message(
                    tool_calls_callback.tool_calls,
                )
            )
            all_messages += await self._run_tools(
                tool_calls_callback.tool_calls, callback
            )
            last_message_block_length = (
                len(all_messages) - previous_message_count
            )

    def _create_command(self, name: str) -> Command:
        if name not in self.commands:
            raise AssistantProtocolException(
                f"The tool '{name}' is expected to be one of {list(self.commands.keys())}"
            )

        command, _ = self.commands[name]
        return command()

    async def _run_tools(
        self,
        tool_calls: list[ChatCompletionMessageToolCallParam],
        callback: ChainCallback,
    ):
        # Execute each requested tool and collect the tool messages to append
        # to the dialog.
        commands: list[CommandInvocation] = []
        command_results: list[CommandResult] = []
        result_messages: list[ChatCompletionMessageParam] = []
        for tool_call in tool_calls:
            function = tool_call["function"]
            name = function["name"]
            arguments: dict[str, Any] = json.loads(function["arguments"])
            with callback.command_callback() as command_callback:
                _publish_command(command_callback, name, json.dumps(arguments))
                command = self._create_command(name)
                result = await self._execute_command(
                    command,
                    arguments,
                    command_callback,
                )
                result_messages.append(
                    tool_message(
                        content=result["response"],
                        tool_call_id=tool_call["id"],
                    )
                )
                command_results.append(result)
                commands.append(
                    CommandInvocation(command=name, arguments=arguments)
                )

        callback.on_state(
            commands_to_text(commands), responses_to_text(command_results)
        )
        return result_messages

    @staticmethod
    async def _execute_command(
        command: Command,
        args: dict[str, Any],
        command_callback: CommandCallback,
    ) -> CommandResult:
        # Execution errors are reported back to the model as ERROR results
        # instead of being raised.
        try:
            result = await command.execute(
                args, command_callback.execution_callback()
            )
            command_callback.on_result(result)
            return CommandResult(status=Status.SUCCESS, response=result.text)
        except Exception as e:
            command_callback.on_error(e)
            return CommandResult(status=Status.ERROR, response=str(e))
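

# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal example of how a ToolsChain could be wired up. Each CommandToolDict
# entry pairs a command constructor with the OpenAI tool definition advertised
# to the model. The `model`, `callback`, and `make_search_command` parameters
# below are assumptions standing in for concrete ModelClient, ChainCallback,
# and Command implementations provided elsewhere in the assistant.
async def _example_run(
    model: ModelClient,
    callback: ChainCallback,
    make_search_command: CommandConstructor,
) -> None:
    commands: CommandToolDict = {
        "search": (
            make_search_command,
            ChatCompletionToolParam(
                type="function",
                function={
                    "name": "search",
                    "description": "Search the web for a query.",
                    "parameters": {
                        "type": "object",
                        "properties": {"query": {"type": "string"}},
                        "required": ["query"],
                    },
                },
            ),
        ),
    }
    chain = ToolsChain(model, commands, max_completion_tokens=1024)
    await chain.run_chat(
        [{"role": "user", "content": "What is the weather in Paris?"}],
        callback,
    )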