aidial_adapter_vertexai/chat/tools.py (245 lines of code) (raw):

from typing import Dict, List, Literal, Self, Tuple, assert_never, cast from aidial_sdk.chat_completion import ( Function, FunctionChoice, Message, Role, ToolChoice, ) from aidial_sdk.chat_completion.request import AzureChatCompletionRequest, Tool from google.genai.types import ( FunctionCallingConfigDict as GenAIFunctionCallingConfig, ) from google.genai.types import ( FunctionDeclarationDict as GenAIFunctionDeclaration, ) from google.genai.types import SchemaDict as GenAISchema from google.genai.types import ToolConfigDict as GenAIToolConfig from google.genai.types import ToolDict as GenAITool from pydantic.v1 import BaseModel from vertexai.preview.generative_models import ( FunctionDeclaration as GeminiFunction, ) from vertexai.preview.generative_models import Tool as GeminiTool from vertexai.preview.generative_models import ToolConfig as GeminiToolConfig from aidial_adapter_vertexai.chat.errors import ValidationError FunctionCallingConfig = GeminiToolConfig.FunctionCallingConfig class ToolsConfig(BaseModel): functions: List[Function] """ List of functions/tools. """ required: bool """ True forces the model to call one of the available functions. False allows the model to pick between generating a message or calling one or more tools/functions. """ tool_ids: Dict[str, str] | None """ Mapping from tool call IDs to corresponding tool names. None means that functions are used, not tools. """ @property def is_tool(self) -> bool: return self.tool_ids is not None def not_supported(self) -> None: if self.functions: if self.is_tool: raise ValidationError("The tools aren't supported") else: raise ValidationError("The functions aren't supported") def create_fresh_tool_call_id(self, tool_name: str) -> str: if self.tool_ids is None: raise ValidationError("Function are used, but requested tool id") idx = 1 while True: id = f"{tool_name}_{idx}" if id not in self.tool_ids: self.tool_ids[id] = tool_name return id idx += 1 def get_tool_name(self, tool_call_id: str) -> str: if self.tool_ids is None: raise ValidationError("Function are used, but requested tool name") tool_name = self.tool_ids.get(tool_call_id) if tool_name is None: raise ValidationError(f"Tool call ID not found: {self.tool_ids}") return tool_name @staticmethod def filter_functions( function_call: Literal["auto", "none"] | FunctionChoice, functions: List[Function], ) -> Tuple[bool, List[Function]]: match function_call: case "none": return False, [] case "auto": return False, functions case FunctionChoice(name=name): new_functions = [ func for func in functions if func.name == name ] if not new_functions: raise ValidationError( f"Function {name!r} is not on the list of available functions" ) return True, new_functions case _: assert_never(function_call) @staticmethod def tool_choice_to_function_call( tool_choice: Literal["auto", "none"] | ToolChoice | None, ) -> Literal["auto", "none"] | FunctionChoice | None: match tool_choice: case ToolChoice(function=FunctionChoice(name=name)): return FunctionChoice(name=name) case _: return tool_choice @classmethod def noop(cls) -> Self: return cls(functions=[], required=False, tool_ids=None) def is_empty(self) -> bool: return not self.functions @classmethod def from_request(cls, request: AzureChatCompletionRequest) -> Self: validate_messages(request) if request.functions is not None: functions = request.functions function_call = request.function_call tool_ids = None elif request.tools is not None: functions = [ tool.function for tool in request.tools if isinstance(tool, Tool) ] function_call = ToolsConfig.tool_choice_to_function_call( request.tool_choice ) tool_ids = collect_tool_ids(request.messages) else: functions = [] function_call = None tool_ids = None if function_call is None: function_call = "auto" if functions else "none" required, selected = ToolsConfig.filter_functions( function_call, functions ) return cls(functions=selected, required=required, tool_ids=tool_ids) def to_gemini_tools(self) -> List[GeminiTool]: if not self.functions: return [] return [ GeminiTool( function_declarations=[ GeminiFunction( name=func.name, parameters=func.parameters or {"type": "object", "properties": {}}, description=func.description, ) for func in self.functions ] ) ] def to_gemini_tool_config(self) -> GeminiToolConfig | None: if not self.functions: return None if self.required: return GeminiToolConfig( function_calling_config=FunctionCallingConfig( mode=FunctionCallingConfig.Mode.ANY, allowed_function_names=[ func.name for func in self.functions ], ) ) else: return GeminiToolConfig( function_calling_config=FunctionCallingConfig( mode=FunctionCallingConfig.Mode.AUTO ) ) def to_gemini_genai_tools(self) -> List[GenAITool]: if not self.functions: return [] return [ GenAITool( function_declarations=[ GenAIFunctionDeclaration( name=func.name, parameters=( _convert_genai_function_parameters(func.parameters) if func.parameters else GenAISchema(type="OBJECT", properties={}) ), description=func.description, ) for func in self.functions ] ) ] def to_gemini_genai_tool_config(self) -> GenAIToolConfig | None: if not self.functions: return None if self.required: return GenAIToolConfig( function_calling_config=GenAIFunctionCallingConfig( mode="ANY", allowed_function_names=[ func.name for func in self.functions ], ) ) else: return GenAIToolConfig( function_calling_config=GenAIFunctionCallingConfig( mode="AUTO", ) ) def validate_messages(request: AzureChatCompletionRequest) -> None: decl_tools = request.tools is not None decl_functions = request.functions is not None if decl_functions and decl_tools: raise ValidationError("Both functions and tools are not allowed") for message in request.messages: if message.role == Role.ASSISTANT: use_tools = message.tool_calls is not None if use_tools and not decl_tools: raise ValidationError( "Assistant message uses tools, but tools are not declared" ) use_functions = message.function_call is not None if use_functions and not decl_functions: raise ValidationError( "Assistant message uses functions, but functions are not declared" ) if message.role == Role.FUNCTION: if not decl_functions: raise ValidationError( "Function message is used, but functions are not declared" ) if message.role == Role.TOOL: if not decl_tools: raise ValidationError( "Tool message is used, but tools are not declared" ) def collect_tool_ids(messages: List[Message]) -> Dict[str, str]: ret: Dict[str, str] = {} for message in messages: if message.role == Role.ASSISTANT and message.tool_calls is not None: for tool_call in message.tool_calls: ret[tool_call.id] = tool_call.function.name return ret def _convert_genai_function_parameters(function_schema: dict) -> GenAISchema: def _convert_schema(schema: dict | str | list): if not isinstance(schema, dict): return schema genai_schema = {} for field, value in schema.items(): if field == "type": # GenAI function parameters should have types in uppercase genai_schema[field] = value.upper() elif isinstance(value, str): genai_schema[field] = value elif isinstance(value, list): genai_schema[field] = [_convert_schema(item) for item in value] elif isinstance(value, dict): genai_schema[field] = { key: _convert_schema(value) for key, value in value.items() } else: raise ValueError( f"Failed to convert function declaration to Vertex format: {schema}" ) return genai_schema return cast(GenAISchema, _convert_schema(function_schema))