aidial_adapter_vertexai/chat/gemini/inputs.py (154 lines of code) (raw):

from typing import Any, List, Tuple, assert_never from aidial_sdk.chat_completion import Message, Role from aidial_adapter_vertexai.chat.errors import ValidationError from aidial_adapter_vertexai.chat.gemini.conversation_factory import ( ConversationFactory, ConversationFactoryBase, GeminiConversation, GeminiConversationT, GeminiGenAIConversation, GenAIConversationFactory, PartT, ) from aidial_adapter_vertexai.chat.gemini.processor import ( AttachmentProcessors, AttachmentProcessorsBase, AttachmentProcessorsGenAI, ) from aidial_adapter_vertexai.chat.tools import ToolsConfig FunctionName = str FunctionArgs = str async def messages_to_gemini_conversation_base( conversation_factory: ConversationFactoryBase[ PartT, Any, GeminiConversationT ], processors: AttachmentProcessorsBase[PartT], tools: ToolsConfig, messages: List[Message], ) -> GeminiConversationT: gemini_messages = [ ( message.role, await _message_to_gemini_parts( processors, tools, message, conversation_factory ), ) for message in messages ] system_instruction, gemini_messages = separate_system_messages( gemini_messages ) contents = [ conversation_factory.create_content(role, parts) for role, parts in gemini_messages ] return conversation_factory.create_conversation( system_instruction, contents, ) async def messages_to_gemini_conversation( conversation_factory: ConversationFactory, processors: AttachmentProcessors, tools: ToolsConfig, messages: List[Message], ) -> GeminiConversation: return await messages_to_gemini_conversation_base( conversation_factory, processors, tools, messages, ) async def messages_to_gemini_genai_conversation( conversation_factory: GenAIConversationFactory, processors: AttachmentProcessorsGenAI, tools: ToolsConfig, messages: List[Message], ) -> GeminiGenAIConversation: return await messages_to_gemini_conversation_base( conversation_factory, processors, tools, messages, ) async def _message_to_gemini_parts( processors: AttachmentProcessorsBase[PartT], tools: ToolsConfig, message: Message, conversation_factory: ConversationFactoryBase, ) -> List[PartT]: content = message.content match message.role: case Role.SYSTEM: if content is None: raise ValidationError("System message content must be present") return await processors.process_message(message) case Role.USER: if not content: raise ValidationError("User message content must be present") return await processors.process_message(message) case Role.ASSISTANT: if message.function_call is not None: return [ conversation_factory.create_function_call_part( message.function_call.name, message.function_call.arguments, ) ] elif message.tool_calls is not None: return [ conversation_factory.create_function_call_part( call.function.name, call.function.arguments ) for call in message.tool_calls ] else: if not content: raise ValidationError( "Assistant message content must be present" ) return await processors.process_message(message) case Role.FUNCTION: if content is None: raise ValidationError( "Function message content must be present" ) if not isinstance(content, str): raise ValidationError( "Function message content must be a string" ) name = message.name if name is None: raise ValidationError("Function message name must be present") return [ conversation_factory.create_function_result_part(name, content) ] case Role.TOOL: if content is None: raise ValidationError("Tool message content must be present") if not isinstance(content, str): raise ValidationError("Tool message content must be a string") tool_call_id = message.tool_call_id if tool_call_id is None: raise ValidationError( "Tool message tool_call_id must be present" ) name = tools.get_tool_name(tool_call_id) return [ conversation_factory.create_function_result_part(name, content) ] case _: assert_never(message.role) def separate_system_messages( messages: List[Tuple[Role, List[PartT]]] ) -> Tuple[List[PartT] | None, List[Tuple[Role, List[PartT]]]]: """ Extract the leading system messages from the list of messages. """ if len(messages) == 0: return None, messages system_messages: List[PartT] = [] while messages: role, message = messages[0] if role == Role.SYSTEM: system_messages.extend(message) messages = messages[1:] else: break return system_messages or None, messages