aidial_adapter_bedrock/llm/tools/claude_emulator.py (99 lines of code) (raw):

""" Legacy tools support for Claude models: https://docs.anthropic.com/claude/docs/legacy-tool-use """ from typing import List, Optional from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.llm.message import ( AIFunctionCallMessage, AIRegularMessage, AIToolCallMessage, BaseMessage, HumanFunctionResultMessage, HumanRegularMessage, HumanToolResultMessage, SystemMessage, ToolMessage, ) from aidial_adapter_bedrock.llm.tools.call_recognizer import CallRecognizer from aidial_adapter_bedrock.llm.tools.claude_protocol import ( FUNC_END_TAG, FUNC_START_TAG, get_system_message, parse_call, print_function_call, print_function_call_result, print_tool_calls, print_tool_declarations, ) from aidial_adapter_bedrock.llm.tools.emulator import ToolsEmulator from aidial_adapter_bedrock.llm.tools.tools_config import ToolsConfig def convert_to_base_message( tool_config: Optional[ToolsConfig], msg: ToolMessage ) -> BaseMessage: match msg: case HumanToolResultMessage(id=id, content=content): if tool_config is None: raise ValidationError( "Tool message is used, but tools are not declared" ) name = tool_config.get_tool_name(id) return HumanRegularMessage( content=print_function_call_result(name=name, content=content) ) case HumanFunctionResultMessage(name=name, content=content): return HumanRegularMessage( content=print_function_call_result(name=name, content=content) ) case AIToolCallMessage(calls=calls): return AIRegularMessage(content=print_tool_calls(calls)) case AIFunctionCallMessage(call=call): return AIRegularMessage(content=print_function_call(call)) class Claude2_1_ToolsEmulator(ToolsEmulator): call_recognizer: CallRecognizer class Config: arbitrary_types_allowed = True @property def _tool_declarations(self) -> Optional[str]: return self.tool_config and print_tool_declarations( self.tool_config.functions ) def add_tool_declarations( self, messages: List[BaseMessage] ) -> List[BaseMessage]: if self._tool_declarations is None: return messages system_message = get_system_message(self._tool_declarations) # Concat with the user system message if len(messages) > 0 and isinstance(messages[0], SystemMessage): system_message += "\n" + messages[0].text_content messages = messages[1:] return [SystemMessage(content=system_message), *messages] def get_stop_sequences(self) -> List[str]: return [] if self._tool_declarations is None else [FUNC_END_TAG] def convert_to_base_messages( self, messages: List[BaseMessage | ToolMessage] ) -> List[BaseMessage]: return [ ( message if isinstance(message, BaseMessage) else convert_to_base_message(self.tool_config, message) ) for message in messages ] def recognize_call( self, content: str | None ) -> str | AIToolCallMessage | AIFunctionCallMessage | None: return ( self.call_recognizer.consume_chunk(content) if self.tool_config else content ) def legacy_tools_emulator( tool_config: Optional[ToolsConfig], ) -> ToolsEmulator: return Claude2_1_ToolsEmulator( tool_config=tool_config, call_recognizer=CallRecognizer( start_tag=FUNC_START_TAG, call_parser=lambda text: parse_call( tool_config, text + FUNC_END_TAG ), ), )