aidial_adapter_vertexai/chat/static_tools.py (135 lines of code) (raw):

"""Validation and conversion of DIAL static tools into Gemini tool objects.

Static functions found in a chat completion request are converted either into
Vertex AI SDK tools (``GeminiTool``) or into Google GenAI SDK tool dicts
(``GenAITool``). Currently the only recognized static function is
``google_search``.
"""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Generic, List, Literal, NoReturn, Self, TypeVar

from aidial_sdk.chat_completion.request import (
    AzureChatCompletionRequest,
    StaticFunction,
    StaticTool,
)
from google.genai.types import GoogleSearchDict as GenAIGoogleSearch
from google.genai.types import ToolDict as GenAITool
from pydantic.v1 import BaseModel, ConstrainedFloat, Field
from pydantic.v1 import ValidationError as PydanticValidationError
from pydantic.v1 import root_validator
from vertexai.preview.generative_models import Tool as GeminiTool

from aidial_adapter_vertexai.chat.errors import ValidationError


class ToolName(str, Enum):
    """Names of the static functions recognized by this module."""

    # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/grounding
    GOOGLE_SEARCH = "google_search"


ToolT = TypeVar("ToolT")


class StaticToolProcessor(ABC, Generic[ToolT]):
    """Converts a single static function into SDK-specific tool objects."""

    @staticmethod
    @abstractmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[ToolT] | None:
        """Return the tools for *static_function*.

        Returns ``None`` when the processor does not recognize the function
        name. The implementations in this module raise ``ValidationError``
        when the function is recognized but its configuration is rejected.
        """
        ...


class DynamicThreshold(ConstrainedFloat):
    """A float constrained to the closed interval [0, 1]."""

    ge = 0
    le = 1


class DynamicRetrievalConfig(BaseModel):
    """Dynamic retrieval settings of the Google search grounding tool.

    Fields may be populated either by their snake_case names or by their
    camelCase aliases; unknown keys are rejected.
    """

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True

    mode: Literal["MODE_DYNAMIC", "MODE_UNSPECIFIED"] | None = None
    dynamic_threshold: DynamicThreshold | None = Field(
        None, alias="dynamicThreshold"
    )

    @root_validator(pre=True)
    def check_dynamic_threshold(cls, values):
        """Reject a threshold when the mode is ``MODE_UNSPECIFIED``."""
        # pre=True: `values` is the raw input, so the threshold may appear
        # under either the field name or its alias.
        # NOTE(review): the adapter's ValidationError is raised here instead
        # of ValueError -- presumably so this specific message reaches the
        # user rather than being wrapped by pydantic; confirm against
        # aidial_adapter_vertexai.chat.errors.
        if values.get("mode") == "MODE_UNSPECIFIED" and (
            values.get("dynamic_threshold") is not None
            or values.get("dynamicThreshold") is not None
        ):
            raise ValidationError(
                "dynamic_threshold must be None when mode is MODE_UNSPECIFIED"
            )
        return values


class GoogleSearchConfig(BaseModel):
    """User-supplied configuration of the ``google_search`` static function."""

    class Config:
        extra = "forbid"
        allow_population_by_field_name = True

    dynamic_retrieval_config: DynamicRetrievalConfig | None = Field(
        None, alias="dynamicRetrievalConfig"
    )


class GoogleSearchGroundingTool(StaticToolProcessor[GeminiTool]):
    """Builds a Vertex AI SDK ``google_search_retrieval`` tool."""

    @staticmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GeminiTool] | None:
        """Return a one-element tool list for ``google_search``, else None.

        Raises:
            ValidationError: if the configuration fails validation.
        """
        if static_function.name != ToolName.GOOGLE_SEARCH:
            return None

        google_search_config = GoogleSearchConfig(dynamicRetrievalConfig=None)
        if static_function.configuration:
            try:
                # parse_obj (unlike BaseModel.validate) reports non-mapping
                # input as a pydantic ValidationError, so every malformed
                # configuration is handled by the except clause below.
                google_search_config = GoogleSearchConfig.parse_obj(
                    static_function.configuration
                )
            except PydanticValidationError as e:
                raise ValidationError(
                    "Invalid configuration for Google search tool"
                ) from e

        return [
            GeminiTool.from_dict(
                {
                    "google_search_retrieval": google_search_config.dict(
                        exclude_none=True
                    )
                }
            )
        ]


class GenAIGoogleSearchTool(StaticToolProcessor[GenAITool]):
    """Builds a Google GenAI SDK ``google_search`` tool."""

    @staticmethod
    def parse_gemini_tools(
        static_function: StaticFunction,
    ) -> List[GenAITool] | None:
        """Return a one-element tool list for ``google_search``, else None.

        Raises:
            ValidationError: if any (non-empty) configuration is supplied,
                since this tool accepts none.
        """
        if static_function.name != ToolName.GOOGLE_SEARCH:
            return None
        if static_function.configuration:
            raise ValidationError(
                "Model doesn't support configuration for Google search tool"
            )
        return [GenAITool(google_search=GenAIGoogleSearch())]


def unknown_tool_name(
    static_function: StaticFunction,
) -> NoReturn:
    """Raise ``ValidationError`` naming the unsupported static function."""
    raise ValidationError(
        f"Unsupported static function: {static_function.name}"
    )


class StaticToolsConfig(BaseModel):
    """The static functions taken from a chat completion request."""

    functions: List[StaticFunction]

    @classmethod
    def from_request(cls, request: AzureChatCompletionRequest) -> Self:
        """Collect the static functions of every ``StaticTool`` in *request*.

        Tools of any other kind are ignored.
        """
        if request.tools is None:
            return cls(functions=[])
        return cls(
            functions=[
                tool.static_function
                for tool in request.tools
                if isinstance(tool, StaticTool)
            ]
        )

    @classmethod
    def noop(cls) -> Self:
        """Return a config with no static functions."""
        return cls(functions=[])

    def _parse_tools(
        self, processor: type[StaticToolProcessor[ToolT]]
    ) -> List[ToolT]:
        """Convert every function with *processor*, in order.

        Raises ``ValidationError`` for a function the processor does not
        recognize (via ``unknown_tool_name``).
        """
        ret: List[ToolT] = []
        for tool in self.functions:
            ret.extend(
                processor.parse_gemini_tools(tool) or unknown_tool_name(tool)
            )
        return ret

    def to_gemini_tools(self) -> List[GeminiTool]:
        """Convert the functions into Vertex AI SDK tools."""
        return self._parse_tools(GoogleSearchGroundingTool)

    def to_gemini_genai_tools(self) -> List[GenAITool]:
        """Convert the functions into Google GenAI SDK tool dicts."""
        return self._parse_tools(GenAIGoogleSearchTool)

    def not_supported(self) -> None:
        """Raise ``ValidationError`` if any static function is present."""
        if self.functions:
            raise ValidationError("Static tools aren't supported")

    def is_empty(self) -> bool:
        """Return True when there are no static functions."""
        return not self.functions