aidial_assistant/commands/run_tool.py (89 lines of code) (raw):

from typing import Any

from langchain_community.tools.openapi.utils.api_models import (
    APIOperation,
    APIPropertyBase,
)
from openai.types.chat import ChatCompletionToolParam
from typing_extensions import override

from aidial_assistant.commands.base import (
    Command,
    ExecutionCallback,
    ResultObject,
    TextResult,
    get_required_field,
)
from aidial_assistant.commands.open_api import OpenAPIChatCommand
from aidial_assistant.commands.plugin_callback import PluginChainCallback
from aidial_assistant.commands.run_plugin import PluginInfo
from aidial_assistant.model.model_client import (
    ModelClient,
    ReasonLengthException,
)
from aidial_assistant.open_api.operation_selector import collect_operations
from aidial_assistant.tools_chain.tools_chain import (
    CommandTool,
    CommandToolDict,
    ToolsChain,
)
from aidial_assistant.utils.open_ai import (
    construct_tool,
    system_message,
    user_message,
)


def _construct_property(p: APIPropertyBase) -> dict[str, Any]:
    """Build the schema fragment describing a single operation property.

    Only ``type`` and ``description`` are emitted. Keys whose value is
    ``None`` are dropped so the resulting schema contains no explicit nulls.
    """
    parameter = {
        "type": p.type,
        "description": p.description,
    }
    return {k: v for k, v in parameter.items() if v is not None}


def _construct_tool(op: APIOperation) -> ChatCompletionToolParam:
    """Convert an OpenAPI operation into a chat-completion tool definition.

    Properties from ``op.properties`` and, when present, from
    ``op.request_body.properties`` are merged into one flat property map.
    If the same name occurs in both sources, the request-body definition is
    written last and therefore wins.

    The ``required`` list keeps first-seen order but never repeats a name,
    since JSON Schema requires the items of ``required`` to be unique.
    """
    sources = list(op.properties)
    if op.request_body is not None:
        sources.extend(op.request_body.properties)

    properties: dict[str, dict[str, Any]] = {}
    required: list[str] = []
    for p in sources:
        properties[p.name] = _construct_property(p)
        # A name shared by a parameter and a body property must be listed
        # only once, otherwise the generated schema is invalid.
        if p.required and p.name not in required:
            required.append(p.name)

    return construct_tool(
        op.operation_id, op.description or "", properties, required
    )


class RunTool(Command):
    """Command that lets the model call a plugin's OpenAPI operations as tools.

    The plugin's operations are exposed to a ``ToolsChain``. That chain is
    seeded with the plugin's model-facing description and the user's query.
    """

    def __init__(
        self, model: ModelClient, plugin: PluginInfo, max_completion_tokens: int
    ):
        # model: client passed to ToolsChain to drive the tool-calling chat.
        # plugin: supplies the OpenAPI spec, API URL, auth and description.
        # max_completion_tokens: forwarded to ToolsChain unchanged.
        self.model = model
        self.plugin = plugin
        self.max_completion_tokens = max_completion_tokens

    @staticmethod
    def token():
        """Return this command's identifier, ``"run-tool"``."""
        return "run-tool"

    @override
    async def execute(
        self, args: dict[str, Any], execution_callback: ExecutionCallback
    ) -> ResultObject:
        """Answer ``args["query"]`` by running a tool chain over the plugin API.

        Args:
            args: Command arguments; the ``"query"`` field is required
                (read via ``get_required_field``).
            execution_callback: Wrapped in a ``PluginChainCallback`` that
                collects the chain's result.

        Returns:
            A ``TextResult`` built from the callback's collected result.
        """
        query = get_required_field(args, "query")

        ops = collect_operations(
            self.plugin.info.open_api, self.plugin.info.ai_plugin.api.url
        )

        def create_command_tool(op: APIOperation) -> CommandTool:
            # `op` is a parameter of this helper, so each factory lambda
            # captures its own operation instead of a shared loop variable.
            return (
                lambda: OpenAPIChatCommand(op, self.plugin.auth),
                _construct_tool(op),
            )

        commands: CommandToolDict = {
            name: create_command_tool(op) for name, op in ops.items()
        }
        chain = ToolsChain(self.model, commands, self.max_completion_tokens)

        messages = [
            system_message(self.plugin.info.ai_plugin.description_for_model),
            user_message(query),
        ]
        chain_callback = PluginChainCallback(execution_callback)
        try:
            await chain.run_chat(messages, chain_callback)
        except ReasonLengthException:
            # Deliberately not an error: when the completion length limit is
            # hit, we still return whatever the callback has collected
            # (presumably partial output — confirm in PluginChainCallback).
            pass

        return TextResult(chain_callback.result)