aidial_adapter_bedrock/llm/tools/tools_config.py (152 lines of code) (raw):

from enum import Enum from typing import Dict, List, Literal, Self, Tuple, assert_never from aidial_sdk.chat_completion import ( Function, FunctionChoice, Message, Role, Tool, ToolChoice, ) from aidial_sdk.chat_completion.request import ( AzureChatCompletionRequest, StaticTool, ) from pydantic import BaseModel from aidial_adapter_bedrock.llm.errors import ValidationError class ToolsMode(Enum): TOOLS = "TOOLS" """ Functions are deprecated instrument, that came before tools """ FUNCTIONS = "FUNCTIONS" class ToolsConfig(BaseModel): functions: List[Function] """ List of functions/tools. """ required: bool """ True forces the model to call one of the available functions. False allows the model to pick between generating a message or calling one or more tools/functions. """ tool_ids: Dict[str, str] | None """ Mapping from tool call IDs to corresponding tool names. None means that functions are used, not tools. """ @property def tools_mode(self) -> ToolsMode: if self.tool_ids is not None: return ToolsMode.TOOLS else: return ToolsMode.FUNCTIONS def not_supported(self) -> None: if self.functions: if self.tools_mode == ToolsMode.TOOLS: raise ValidationError("The tools aren't supported") else: raise ValidationError("The functions aren't supported") def create_fresh_tool_call_id(self, tool_name: str) -> str: if self.tool_ids is None: raise ValidationError("Function are used, but requested tool id") idx = 1 while True: id = f"{tool_name}_{idx}" if id not in self.tool_ids: self.tool_ids[id] = tool_name return id idx += 1 def get_tool_name(self, tool_call_id: str) -> str: if self.tool_ids is None: raise ValidationError("Function are used, but requested tool name") tool_name = self.tool_ids.get(tool_call_id) if tool_name is None: raise ValidationError(f"Tool call ID not found: {self.tool_ids}") return tool_name @staticmethod def filter_functions( function_call: Literal["auto", "none"] | FunctionChoice, functions: List[Function], ) -> Tuple[bool, List[Function]]: match function_call: case "none": return False, [] case "auto": return False, functions case FunctionChoice(name=name): new_functions = [ func for func in functions if func.name == name ] if not new_functions: raise ValidationError( f"Function {name!r} is not on the list of available functions" ) return True, new_functions case _: assert_never(function_call) @staticmethod def tool_choice_to_function_call( tool_choice: Literal["auto", "none"] | ToolChoice | None, ) -> Literal["auto", "none"] | FunctionChoice | None: match tool_choice: case ToolChoice(function=FunctionChoice(name=name)): return FunctionChoice(name=name) case _: return tool_choice @staticmethod def _get_function_from_tool(tool: Tool | StaticTool) -> Function: if isinstance(tool, Tool): return tool.function elif isinstance(tool, StaticTool): raise ValidationError("Static tools aren't supported") else: assert_never(tool) @classmethod def from_request(cls, request: AzureChatCompletionRequest) -> Self | None: validate_messages(request) if request.functions is not None: functions = request.functions function_call = request.function_call tool_ids = None elif request.tools is not None: functions = [ ToolsConfig._get_function_from_tool(tool) for tool in request.tools ] function_call = ToolsConfig.tool_choice_to_function_call( request.tool_choice ) tool_ids = collect_tool_ids(request.messages) else: functions = [] function_call = None tool_ids = None if function_call is None: function_call = "auto" if functions else "none" required, selected = ToolsConfig.filter_functions( function_call, functions ) if selected == []: return None return cls(functions=selected, required=required, tool_ids=tool_ids) def validate_messages(request: AzureChatCompletionRequest) -> None: decl_tools = request.tools is not None decl_functions = request.functions is not None if decl_functions and decl_tools: raise ValidationError("Both functions and tools are not allowed") for message in request.messages: if message.role == Role.ASSISTANT: use_tools = message.tool_calls is not None if use_tools and not decl_tools: raise ValidationError( "Assistant message uses tools, but tools are not declared" ) use_functions = message.function_call is not None if use_functions and not decl_functions: raise ValidationError( "Assistant message uses functions, but functions are not declared" ) if message.role == Role.FUNCTION: if not decl_functions: raise ValidationError( "Function message is used, but functions are not declared" ) if message.role == Role.TOOL: if not decl_tools: raise ValidationError( "Tool message is used, but tools are not declared" ) def collect_tool_ids(messages: List[Message]) -> Dict[str, str]: ret: Dict[str, str] = {} for message in messages: if message.role == Role.ASSISTANT and message.tool_calls is not None: for tool_call in message.tool_calls: ret[tool_call.id] = tool_call.function.name return ret