aidial_adapter_bedrock/llm/tools/tools_config.py (152 lines of code) (raw):
from enum import Enum
from typing import Dict, List, Literal, Self, Tuple, assert_never
from aidial_sdk.chat_completion import (
Function,
FunctionChoice,
Message,
Role,
Tool,
ToolChoice,
)
from aidial_sdk.chat_completion.request import (
AzureChatCompletionRequest,
StaticTool,
)
from pydantic import BaseModel
from aidial_adapter_bedrock.llm.errors import ValidationError
class ToolsMode(Enum):
    """
    Which request API was used to declare the callable functions.
    """

    TOOLS = "TOOLS"
    """Functions were declared through the ``tools`` request field."""

    FUNCTIONS = "FUNCTIONS"
    """
    Functions were declared through the deprecated ``functions`` request
    field, the mechanism that predates tools.
    """
class ToolsConfig(BaseModel):
functions: List[Function]
"""
List of functions/tools.
"""
required: bool
"""
True forces the model to call one of the available functions.
False allows the model to pick between generating a message or
calling one or more tools/functions.
"""
tool_ids: Dict[str, str] | None
"""
Mapping from tool call IDs to corresponding tool names.
None means that functions are used, not tools.
"""
@property
def tools_mode(self) -> ToolsMode:
if self.tool_ids is not None:
return ToolsMode.TOOLS
else:
return ToolsMode.FUNCTIONS
def not_supported(self) -> None:
if self.functions:
if self.tools_mode == ToolsMode.TOOLS:
raise ValidationError("The tools aren't supported")
else:
raise ValidationError("The functions aren't supported")
def create_fresh_tool_call_id(self, tool_name: str) -> str:
if self.tool_ids is None:
raise ValidationError("Function are used, but requested tool id")
idx = 1
while True:
id = f"{tool_name}_{idx}"
if id not in self.tool_ids:
self.tool_ids[id] = tool_name
return id
idx += 1
def get_tool_name(self, tool_call_id: str) -> str:
if self.tool_ids is None:
raise ValidationError("Function are used, but requested tool name")
tool_name = self.tool_ids.get(tool_call_id)
if tool_name is None:
raise ValidationError(f"Tool call ID not found: {self.tool_ids}")
return tool_name
@staticmethod
def filter_functions(
function_call: Literal["auto", "none"] | FunctionChoice,
functions: List[Function],
) -> Tuple[bool, List[Function]]:
match function_call:
case "none":
return False, []
case "auto":
return False, functions
case FunctionChoice(name=name):
new_functions = [
func for func in functions if func.name == name
]
if not new_functions:
raise ValidationError(
f"Function {name!r} is not on the list of available functions"
)
return True, new_functions
case _:
assert_never(function_call)
@staticmethod
def tool_choice_to_function_call(
tool_choice: Literal["auto", "none"] | ToolChoice | None,
) -> Literal["auto", "none"] | FunctionChoice | None:
match tool_choice:
case ToolChoice(function=FunctionChoice(name=name)):
return FunctionChoice(name=name)
case _:
return tool_choice
@staticmethod
def _get_function_from_tool(tool: Tool | StaticTool) -> Function:
if isinstance(tool, Tool):
return tool.function
elif isinstance(tool, StaticTool):
raise ValidationError("Static tools aren't supported")
else:
assert_never(tool)
@classmethod
def from_request(cls, request: AzureChatCompletionRequest) -> Self | None:
validate_messages(request)
if request.functions is not None:
functions = request.functions
function_call = request.function_call
tool_ids = None
elif request.tools is not None:
functions = [
ToolsConfig._get_function_from_tool(tool)
for tool in request.tools
]
function_call = ToolsConfig.tool_choice_to_function_call(
request.tool_choice
)
tool_ids = collect_tool_ids(request.messages)
else:
functions = []
function_call = None
tool_ids = None
if function_call is None:
function_call = "auto" if functions else "none"
required, selected = ToolsConfig.filter_functions(
function_call, functions
)
if selected == []:
return None
return cls(functions=selected, required=required, tool_ids=tool_ids)
def validate_messages(request: AzureChatCompletionRequest) -> None:
decl_tools = request.tools is not None
decl_functions = request.functions is not None
if decl_functions and decl_tools:
raise ValidationError("Both functions and tools are not allowed")
for message in request.messages:
if message.role == Role.ASSISTANT:
use_tools = message.tool_calls is not None
if use_tools and not decl_tools:
raise ValidationError(
"Assistant message uses tools, but tools are not declared"
)
use_functions = message.function_call is not None
if use_functions and not decl_functions:
raise ValidationError(
"Assistant message uses functions, but functions are not declared"
)
if message.role == Role.FUNCTION:
if not decl_functions:
raise ValidationError(
"Function message is used, but functions are not declared"
)
if message.role == Role.TOOL:
if not decl_tools:
raise ValidationError(
"Tool message is used, but tools are not declared"
)
def collect_tool_ids(messages: List[Message]) -> Dict[str, str]:
    """
    Map every tool call ID found in assistant messages to its function name.

    Messages are scanned in order; if an ID repeats, the later tool call
    wins.
    """
    return {
        call.id: call.function.name
        for msg in messages
        if msg.role == Role.ASSISTANT and msg.tool_calls is not None
        for call in msg.tool_calls
    }