aidial_adapter_vertexai/chat/tools.py (245 lines of code) (raw):
from typing import Dict, List, Literal, Self, Tuple, assert_never, cast
from aidial_sdk.chat_completion import (
Function,
FunctionChoice,
Message,
Role,
ToolChoice,
)
from aidial_sdk.chat_completion.request import AzureChatCompletionRequest, Tool
from google.genai.types import (
FunctionCallingConfigDict as GenAIFunctionCallingConfig,
)
from google.genai.types import (
FunctionDeclarationDict as GenAIFunctionDeclaration,
)
from google.genai.types import SchemaDict as GenAISchema
from google.genai.types import ToolConfigDict as GenAIToolConfig
from google.genai.types import ToolDict as GenAITool
from pydantic.v1 import BaseModel
from vertexai.preview.generative_models import (
FunctionDeclaration as GeminiFunction,
)
from vertexai.preview.generative_models import Tool as GeminiTool
from vertexai.preview.generative_models import ToolConfig as GeminiToolConfig
from aidial_adapter_vertexai.chat.errors import ValidationError
# Short alias for the function-calling config class nested inside the
# Vertex AI ToolConfig; used below when building Gemini tool configs.
FunctionCallingConfig = GeminiToolConfig.FunctionCallingConfig
class ToolsConfig(BaseModel):
functions: List[Function]
"""
List of functions/tools.
"""
required: bool
"""
True forces the model to call one of the available functions.
False allows the model to pick between generating a message or
calling one or more tools/functions.
"""
tool_ids: Dict[str, str] | None
"""
Mapping from tool call IDs to corresponding tool names.
None means that functions are used, not tools.
"""
@property
def is_tool(self) -> bool:
return self.tool_ids is not None
def not_supported(self) -> None:
if self.functions:
if self.is_tool:
raise ValidationError("The tools aren't supported")
else:
raise ValidationError("The functions aren't supported")
def create_fresh_tool_call_id(self, tool_name: str) -> str:
if self.tool_ids is None:
raise ValidationError("Function are used, but requested tool id")
idx = 1
while True:
id = f"{tool_name}_{idx}"
if id not in self.tool_ids:
self.tool_ids[id] = tool_name
return id
idx += 1
def get_tool_name(self, tool_call_id: str) -> str:
if self.tool_ids is None:
raise ValidationError("Function are used, but requested tool name")
tool_name = self.tool_ids.get(tool_call_id)
if tool_name is None:
raise ValidationError(f"Tool call ID not found: {self.tool_ids}")
return tool_name
@staticmethod
def filter_functions(
function_call: Literal["auto", "none"] | FunctionChoice,
functions: List[Function],
) -> Tuple[bool, List[Function]]:
match function_call:
case "none":
return False, []
case "auto":
return False, functions
case FunctionChoice(name=name):
new_functions = [
func for func in functions if func.name == name
]
if not new_functions:
raise ValidationError(
f"Function {name!r} is not on the list of available functions"
)
return True, new_functions
case _:
assert_never(function_call)
@staticmethod
def tool_choice_to_function_call(
tool_choice: Literal["auto", "none"] | ToolChoice | None,
) -> Literal["auto", "none"] | FunctionChoice | None:
match tool_choice:
case ToolChoice(function=FunctionChoice(name=name)):
return FunctionChoice(name=name)
case _:
return tool_choice
@classmethod
def noop(cls) -> Self:
return cls(functions=[], required=False, tool_ids=None)
def is_empty(self) -> bool:
return not self.functions
@classmethod
def from_request(cls, request: AzureChatCompletionRequest) -> Self:
validate_messages(request)
if request.functions is not None:
functions = request.functions
function_call = request.function_call
tool_ids = None
elif request.tools is not None:
functions = [
tool.function
for tool in request.tools
if isinstance(tool, Tool)
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
tool_ids = collect_tool_ids(request.messages)
else:
functions = []
function_call = None
tool_ids = None
if function_call is None:
function_call = "auto" if functions else "none"
required, selected = ToolsConfig.filter_functions(
function_call, functions
)
return cls(functions=selected, required=required, tool_ids=tool_ids)
def to_gemini_tools(self) -> List[GeminiTool]:
if not self.functions:
return []
return [
GeminiTool(
function_declarations=[
GeminiFunction(
name=func.name,
parameters=func.parameters
or {"type": "object", "properties": {}},
description=func.description,
)
for func in self.functions
]
)
]
def to_gemini_tool_config(self) -> GeminiToolConfig | None:
if not self.functions:
return None
if self.required:
return GeminiToolConfig(
function_calling_config=FunctionCallingConfig(
mode=FunctionCallingConfig.Mode.ANY,
allowed_function_names=[
func.name for func in self.functions
],
)
)
else:
return GeminiToolConfig(
function_calling_config=FunctionCallingConfig(
mode=FunctionCallingConfig.Mode.AUTO
)
)
def to_gemini_genai_tools(self) -> List[GenAITool]:
if not self.functions:
return []
return [
GenAITool(
function_declarations=[
GenAIFunctionDeclaration(
name=func.name,
parameters=(
_convert_genai_function_parameters(func.parameters)
if func.parameters
else GenAISchema(type="OBJECT", properties={})
),
description=func.description,
)
for func in self.functions
]
)
]
def to_gemini_genai_tool_config(self) -> GenAIToolConfig | None:
if not self.functions:
return None
if self.required:
return GenAIToolConfig(
function_calling_config=GenAIFunctionCallingConfig(
mode="ANY",
allowed_function_names=[
func.name for func in self.functions
],
)
)
else:
return GenAIToolConfig(
function_calling_config=GenAIFunctionCallingConfig(
mode="AUTO",
)
)
def validate_messages(request: AzureChatCompletionRequest) -> None:
decl_tools = request.tools is not None
decl_functions = request.functions is not None
if decl_functions and decl_tools:
raise ValidationError("Both functions and tools are not allowed")
for message in request.messages:
if message.role == Role.ASSISTANT:
use_tools = message.tool_calls is not None
if use_tools and not decl_tools:
raise ValidationError(
"Assistant message uses tools, but tools are not declared"
)
use_functions = message.function_call is not None
if use_functions and not decl_functions:
raise ValidationError(
"Assistant message uses functions, but functions are not declared"
)
if message.role == Role.FUNCTION:
if not decl_functions:
raise ValidationError(
"Function message is used, but functions are not declared"
)
if message.role == Role.TOOL:
if not decl_tools:
raise ValidationError(
"Tool message is used, but tools are not declared"
)
def collect_tool_ids(messages: List[Message]) -> Dict[str, str]:
    """
    Map each tool call ID found in assistant messages to the name of the
    function it calls. When an ID occurs more than once, the last
    occurrence wins.
    """
    return {
        call.id: call.function.name
        for msg in messages
        if msg.role == Role.ASSISTANT and msg.tool_calls is not None
        for call in msg.tool_calls
    }
def _convert_genai_function_parameters(function_schema: dict) -> GenAISchema:
def _convert_schema(schema: dict | str | list):
if not isinstance(schema, dict):
return schema
genai_schema = {}
for field, value in schema.items():
if field == "type":
# GenAI function parameters should have types in uppercase
genai_schema[field] = value.upper()
elif isinstance(value, str):
genai_schema[field] = value
elif isinstance(value, list):
genai_schema[field] = [_convert_schema(item) for item in value]
elif isinstance(value, dict):
genai_schema[field] = {
key: _convert_schema(value) for key, value in value.items()
}
else:
raise ValueError(
f"Failed to convert function declaration to Vertex format: {schema}"
)
return genai_schema
return cast(GenAISchema, _convert_schema(function_schema))