# aidial_adapter_vertexai/chat/gemini/inputs.py (154 lines of code) (raw):
from typing import Any, List, Tuple, assert_never
from aidial_sdk.chat_completion import Message, Role
from aidial_adapter_vertexai.chat.errors import ValidationError
from aidial_adapter_vertexai.chat.gemini.conversation_factory import (
ConversationFactory,
ConversationFactoryBase,
GeminiConversation,
GeminiConversationT,
GeminiGenAIConversation,
GenAIConversationFactory,
PartT,
)
from aidial_adapter_vertexai.chat.gemini.processor import (
AttachmentProcessors,
AttachmentProcessorsBase,
AttachmentProcessorsGenAI,
)
from aidial_adapter_vertexai.chat.tools import ToolsConfig
# Plain ``str`` aliases; presumably the name and the serialized arguments
# of a function call.
# NOTE(review): neither alias is referenced in the visible part of this
# module - confirm external usage before changing or removing them.
FunctionName = str
FunctionArgs = str
async def messages_to_gemini_conversation_base(
    conversation_factory: ConversationFactoryBase[
        PartT, Any, GeminiConversationT
    ],
    processors: AttachmentProcessorsBase[PartT],
    tools: ToolsConfig,
    messages: List[Message],
) -> GeminiConversationT:
    """Build a Gemini conversation out of a list of chat messages.

    Every message is converted to its parts one after another (in input
    order), the leading system messages are split off as the system
    instruction, and the remaining (role, parts) pairs are wrapped into
    contents by ``conversation_factory``.

    Args:
        conversation_factory: creates contents and the final conversation.
        processors: turns message content into parts.
        tools: used to resolve tool names for tool-result messages.
        messages: the chat messages to convert.

    Returns:
        Whatever ``conversation_factory.create_conversation`` produces for
        the system instruction and the contents.
    """
    role_and_parts: List[Tuple[Role, List[PartT]]] = []
    for msg in messages:
        msg_role = msg.role
        msg_parts = await _message_to_gemini_parts(
            processors, tools, msg, conversation_factory
        )
        role_and_parts.append((msg_role, msg_parts))

    system_instruction, remaining = separate_system_messages(role_and_parts)

    contents = []
    for msg_role, msg_parts in remaining:
        contents.append(
            conversation_factory.create_content(msg_role, msg_parts)
        )

    return conversation_factory.create_conversation(
        system_instruction, contents
    )
async def messages_to_gemini_conversation(
    conversation_factory: ConversationFactory,
    processors: AttachmentProcessors,
    tools: ToolsConfig,
    messages: List[Message],
) -> GeminiConversation:
    """Convert chat messages into a ``GeminiConversation``.

    Typed wrapper that forwards all arguments unchanged to
    ``messages_to_gemini_conversation_base``.
    """
    conversation = await messages_to_gemini_conversation_base(
        conversation_factory=conversation_factory,
        processors=processors,
        tools=tools,
        messages=messages,
    )
    return conversation
async def messages_to_gemini_genai_conversation(
    conversation_factory: GenAIConversationFactory,
    processors: AttachmentProcessorsGenAI,
    tools: ToolsConfig,
    messages: List[Message],
) -> GeminiGenAIConversation:
    """Convert chat messages into a ``GeminiGenAIConversation``.

    Typed wrapper that forwards all arguments unchanged to
    ``messages_to_gemini_conversation_base``.
    """
    conversation = await messages_to_gemini_conversation_base(
        conversation_factory=conversation_factory,
        processors=processors,
        tools=tools,
        messages=messages,
    )
    return conversation
async def _message_to_gemini_parts(
processors: AttachmentProcessorsBase[PartT],
tools: ToolsConfig,
message: Message,
conversation_factory: ConversationFactoryBase,
) -> List[PartT]:
content = message.content
match message.role:
case Role.SYSTEM:
if content is None:
raise ValidationError("System message content must be present")
return await processors.process_message(message)
case Role.USER:
if not content:
raise ValidationError("User message content must be present")
return await processors.process_message(message)
case Role.ASSISTANT:
if message.function_call is not None:
return [
conversation_factory.create_function_call_part(
message.function_call.name,
message.function_call.arguments,
)
]
elif message.tool_calls is not None:
return [
conversation_factory.create_function_call_part(
call.function.name, call.function.arguments
)
for call in message.tool_calls
]
else:
if not content:
raise ValidationError(
"Assistant message content must be present"
)
return await processors.process_message(message)
case Role.FUNCTION:
if content is None:
raise ValidationError(
"Function message content must be present"
)
if not isinstance(content, str):
raise ValidationError(
"Function message content must be a string"
)
name = message.name
if name is None:
raise ValidationError("Function message name must be present")
return [
conversation_factory.create_function_result_part(name, content)
]
case Role.TOOL:
if content is None:
raise ValidationError("Tool message content must be present")
if not isinstance(content, str):
raise ValidationError("Tool message content must be a string")
tool_call_id = message.tool_call_id
if tool_call_id is None:
raise ValidationError(
"Tool message tool_call_id must be present"
)
name = tools.get_tool_name(tool_call_id)
return [
conversation_factory.create_function_result_part(name, content)
]
case _:
assert_never(message.role)
def separate_system_messages(
    messages: List[Tuple[Role, List[PartT]]]
) -> Tuple[List[PartT] | None, List[Tuple[Role, List[PartT]]]]:
    """
    Extract the leading system messages from the list of messages.

    Only the uninterrupted run of system messages at the start of the list
    is extracted; a system message that follows any non-system message is
    left in place.

    Args:
        messages: (role, parts) pairs in conversation order.

    Returns:
        A pair ``(system_parts, remaining)``. ``system_parts`` is the
        concatenation of the parts of the leading system messages, or
        ``None`` when that concatenation is empty. ``remaining`` holds the
        messages after the leading system run; when there is no such run
        the input list itself is returned.
    """
    system_parts: List[PartT] = []
    leading_count = 0

    # Count the leading system messages first and slice once afterwards,
    # instead of re-slicing the list for every system message.
    for role, parts in messages:
        if role != Role.SYSTEM:
            break
        system_parts.extend(parts)
        leading_count += 1

    if leading_count == 0:
        return None, messages

    return system_parts or None, messages[leading_count:]