aidial_adapter_bedrock/llm/tools/claude_protocol.py (139 lines of code) (raw):
import json
from typing import Dict, List, Literal, Optional
from aidial_sdk.chat_completion import Function, FunctionCall, ToolCall
from pydantic import BaseModel
from aidial_adapter_bedrock.llm.message import (
AIFunctionCallMessage,
AIToolCallMessage,
)
from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig, ToolsMode
from aidial_adapter_bedrock.utils.pydantic import ExtraForbidModel
from aidial_adapter_bedrock.utils.xml import parse_xml, tag, tag_nl
# Name of the XML root element that wraps tool invocations in the protocol
# described by the system prompt, plus its literal opening/closing tags.
FUNC_TAG_NAME = "function_calls"
FUNC_START_TAG = "<" + FUNC_TAG_NAME + ">"
FUNC_END_TAG = "</" + FUNC_TAG_NAME + ">"
def get_system_message(tool_declarations: str) -> str:
    """Build the system prompt describing the XML tool-calling protocol.

    The prompt shows the model the ``<function_calls>``/``<invoke>`` syntax it
    must use and ends with *tool_declarations* verbatim. Leading and trailing
    whitespace of the assembled prompt (including any at the end of
    *tool_declarations*) is stripped.
    """
    prompt = f"""
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Avoid showing the function calls and respective results to the user.
Here are the tools available:
{tool_declarations}
"""
    return prompt.strip()
class ToolParameterProperties(ExtraForbidModel):
    # Schema of a single tool parameter, i.e. one value of the `properties`
    # map inside a function's JSON Schema `parameters`.
    # NOTE(review): the base class name suggests unknown keys are rejected —
    # confirm in utils.pydantic.
    type: str
    # No explicit default: pydantic v1 treats Optional fields without a
    # default as optional (None); under pydantic v2 this would be required.
    description: Optional[str]
    default: Optional[str] = None
    # Element schema for array-typed parameters; self-referential, rendered
    # recursively by _print_tool_parameter_properties.
    items: Optional["ToolParameterProperties"] = None
    # Allowed values; rendered as a comma-separated list.
    enum: Optional[List[str]] = None
    # The title is allowed according to the JSON Schema, but not used
    title: Optional[str] = None
class ToolParameters(BaseModel):
    # Top-level JSON Schema of a function's `parameters`; only object schemas
    # are accepted.
    type: Literal["object"]
    # Parameter name -> its schema; iteration order is the order in which the
    # parameters are rendered into the prompt.
    properties: Dict[str, ToolParameterProperties]
    # Parsed for validation only: the printing helpers in this file do not
    # read it, so the prompt does not mark which parameters are required.
    required: Optional[List[str]]
def _print_tool_parameter_properties(
    props: ToolParameterProperties,
) -> list[str | None]:
    """Render the child tags describing one parameter's schema.

    The result always has five entries, in the order type, items, enum,
    description, default. Attributes that are unset are passed to
    ``tag``/``tag_nl`` as ``None``, and the entry is whatever those helpers
    return for it. An ``items`` schema is rendered recursively; ``enum``
    values are joined with ", ".
    """
    rendered: list[str | None] = [tag("type", props.type)]

    nested = _print_tool_parameter_properties(props.items) if props.items else None
    rendered.append(tag_nl("items", nested))

    allowed = ", ".join(props.enum) if props.enum else None
    rendered.append(tag("enum", allowed))

    rendered.append(tag("description", props.description))
    rendered.append(tag("default", props.default))
    return rendered
def _print_tool_parameter(name: str, props: ToolParameterProperties) -> str:
    """Render one ``<parameter>`` element: its ``<name>`` followed by its schema tags."""
    children = [tag("name", name), *_print_tool_parameter_properties(props)]
    return tag_nl("parameter", children)
def _print_tool_parameters(parameters: ToolParameters) -> str:
    """Render the ``<parameters>`` element, one ``<parameter>`` per property.

    Properties are emitted in the iteration order of ``parameters.properties``.
    """
    declared = parameters.properties
    rendered = [_print_tool_parameter(key, declared[key]) for key in declared]
    return tag_nl("parameters", rendered)
def _print_tool_declaration(function: Function) -> str:
    """Render one function as a ``<tool_description>`` element.

    The element contains the tool name, its description and a ``<parameters>``
    section built from ``function.parameters`` validated as ``ToolParameters``.

    A function declared without parameters (``None`` or ``{}``), or whose
    schema omits ``type`` or ``properties``, is treated as an object schema
    with no properties. OpenAI-style clients may omit ``parameters`` to declare
    a zero-argument function. Previously such declarations failed validation.

    Raises:
        pydantic validation error: if the schema still does not match
        ``ToolParameters`` (e.g. a non-object ``type``).
    """
    # Defaults first, so every key the caller supplied takes precedence.
    schema = {"type": "object", "properties": {}, **(function.parameters or {})}
    return tag_nl(
        "tool_description",
        [
            tag("tool_name", function.name),
            tag("description", function.description),
            _print_tool_parameters(ToolParameters.parse_obj(schema)),
        ],
    )
def print_tool_declarations(functions: List[Function]) -> str:
    """Render every available function into a single ``<tools>`` element.

    The result is meant to be embedded in the system prompt produced by
    ``get_system_message``.
    """
    declarations = list(map(_print_tool_declaration, functions))
    return tag_nl("tools", declarations)
def _print_function_call_parameters(parameters: dict) -> str:
    """Render decoded function-call arguments as a ``<parameters>`` element.

    Each argument becomes a child tag named after the argument.

    - String values and ``None`` are passed to ``tag`` unchanged, matching how
      ``tag`` is called elsewhere in this file.
    - Any other JSON value (number, boolean, array, object) is re-serialized
      with ``json.dumps``. The prompt then shows JSON literals such as ``true``
      or ``{"a": 1}``, not whatever ``tag`` would make of a raw Python object
      (presumably its ``str()`` form, e.g. ``True`` / ``{'a': 1}``).
    """

    def render(value: object) -> str | None:
        if value is None or isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    return tag_nl(
        "parameters",
        [tag(name, render(value)) for name, value in parameters.items()],
    )
def print_tool_calls(calls: List[ToolCall]) -> str:
    """Render tool calls as one ``<function_calls>`` element.

    The element holds one ``<invoke>`` per call, following the format shown
    in the system prompt. Each invocation is NOT wrapped in its own
    ``<function_calls>``: the previous implementation nested
    ``<function_calls><function_calls><invoke>``.

    Raises:
        Exception: if any call's arguments are not a JSON object.
    """
    return tag_nl(
        FUNC_TAG_NAME,
        [_print_invoke(call.function) for call in calls],
    )


def print_function_call(call: FunctionCall) -> str:
    """Render a single function call wrapped in ``<function_calls>``.

    Raises:
        Exception: if ``call.arguments`` is not a JSON object.
    """
    return tag_nl(FUNC_TAG_NAME, _print_invoke(call))


def _print_invoke(call: FunctionCall) -> str:
    """Render one function call as a bare ``<invoke>`` element (no outer wrapper).

    Raises:
        Exception: if ``call.arguments`` is not valid JSON, or decodes to
            something other than a JSON object.
    """
    try:
        arguments = json.loads(call.arguments)
    except (TypeError, ValueError) as e:
        raise Exception(
            "Unable to parse function call arguments: it's not a valid JSON"
        ) from e
    # Arguments are rendered as named child tags, which needs a mapping.
    if not isinstance(arguments, dict):
        raise Exception(
            "Unable to parse function call arguments: it's not a JSON object"
        )
    return tag_nl(
        "invoke",
        [
            tag("tool_name", call.name),
            _print_function_call_parameters(arguments),
        ],
    )
def _parse_function_call(text: str) -> FunctionCall:
    """Extract a function call from model output containing ``<function_calls>``.

    Parsing:

    - Text before the opening tag is ignored.
    - If a closing ``</function_calls>`` tag follows, text after it is ignored
      as well, so trailing prose does not break XML parsing. Without a closing
      tag the whole remainder is parsed, as before.
    - An absent ``<parameters>`` element, or one that ``parse_xml`` maps to
      ``None`` (presumably an empty element — confirm in utils.xml), yields
      ``{}`` arguments rather than ``"null"``.

    Raises:
        Exception: if the opening tag is missing, or the fragment cannot be
            parsed into an ``invoke`` with a ``tool_name``. The underlying
            error is chained as the cause.
    """
    start_index = text.find(FUNC_START_TAG)
    if start_index == -1:
        raise Exception(
            f"Unable to parse function call, missing {FUNC_TAG_NAME!r} tag"
        )

    end_index = text.find(FUNC_END_TAG, start_index)
    if end_index == -1:
        fragment = text[start_index:]
    else:
        fragment = text[start_index : end_index + len(FUNC_END_TAG)]

    try:
        parsed = parse_xml(fragment)
        invocation = parsed[FUNC_TAG_NAME]["invoke"]
        tool_name = invocation["tool_name"]
        parameters = invocation.get("parameters")
    except Exception as e:
        raise Exception("Unable to parse function call") from e

    if parameters is None:
        parameters = {}

    return FunctionCall(name=tool_name, arguments=json.dumps(parameters))
def parse_call(
    config: Optional[ToolsConfig], text: str
) -> AIToolCallMessage | AIFunctionCallMessage | None:
    """Convert model output into a tool-call or function-call message.

    Returns None when *config* is None. Otherwise *text* must contain a
    ``<function_calls>`` block; errors from ``_parse_function_call``
    propagate.

    - In ``ToolsMode.TOOLS`` the call is wrapped in a single ``ToolCall``
      (index 0) whose id comes from ``config.create_fresh_tool_call_id``.
    - In any other mode an ``AIFunctionCallMessage`` is returned.
    """
    if config is None:
        return None

    function_call = _parse_function_call(text)

    if config.tools_mode != ToolsMode.TOOLS:
        return AIFunctionCallMessage(call=function_call)

    call_id = config.create_fresh_tool_call_id(function_call.name)
    return AIToolCallMessage(
        calls=[
            ToolCall(index=0, id=call_id, type="function", function=function_call)
        ]
    )
def print_function_call_result(name: str, content: str) -> str:
    """Render a function/tool result as a ``<function_results>`` element.

    The element wraps a single ``<result>`` holding the tool name and the
    result *content* inside ``<stdout>``.
    """
    name_tag = tag("tool_name", name)
    stdout_tag = tag_nl("stdout", content)
    single_result = tag_nl("result", [name_tag, stdout_tag])
    return tag_nl("function_results", [single_result])