aidial_adapter_vertexai/chat/gemini/conversation_factory.py (133 lines of code) (raw):
import json
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar, assert_never
from aidial_sdk.chat_completion.request import Role
from google.genai.types import Content as GenAIContent
from google.genai.types import Part as GenAIPart
from pydantic.v1 import BaseModel
from vertexai.preview.generative_models import ChatSession, Content, Part
from aidial_adapter_vertexai.chat.errors import ValidationError
# Type parameters shared by the factory and conversation generics below.
# PartT: SDK-specific message part type (vertexai ``Part`` or google-genai ``Part``).
# ContentT: SDK-specific content type (a role plus a list of parts).
# GeminiConversationT: conversation container produced by a factory.
PartT = TypeVar("PartT")
ContentT = TypeVar("ContentT")
GeminiConversationT = TypeVar("GeminiConversationT")
class ConversationFactoryBase(
    ABC, Generic[PartT, ContentT, GeminiConversationT]
):
    """Abstract builder of SDK-specific Gemini conversation objects.

    Concrete subclasses bind ``PartT``, ``ContentT`` and
    ``GeminiConversationT`` to the types of one SDK and implement every
    hook below.
    """

    @abstractmethod
    def create_multi_modal_part(self, data: bytes, mime_type: str) -> PartT:
        """Wrap raw binary ``data`` of the given ``mime_type`` into a part."""

    @abstractmethod
    def create_text_part(self, text: str) -> PartT:
        """Wrap plain ``text`` into a part."""

    @abstractmethod
    def create_function_call_part(self, name: str, args: str) -> PartT:
        """Build a part describing a call of function ``name``.

        ``args`` is the string form of the call arguments.
        """

    @abstractmethod
    def create_function_result_part(self, name: str, args: str) -> PartT:
        """Build a part carrying the result of function ``name``.

        Despite its name, ``args`` holds the string form of the result.
        """

    @abstractmethod
    def create_content(self, role: Role, parts: List[PartT]) -> ContentT:
        """Group ``parts`` into one content entry attributed to ``role``."""

    @abstractmethod
    def create_conversation(
        self,
        system_instruction: List[PartT] | None,
        contents: List[ContentT],
    ) -> GeminiConversationT:
        """Assemble the conversation from optional system-instruction parts
        and the list of message contents."""
class GeminiConversationBase(BaseModel, Generic[PartT, ContentT]):
    """Pydantic container pairing an optional system instruction with the
    conversation contents, parametrised by SDK part/content types.

    NOTE(review): with pydantic.v1 a plain ``Generic`` base (not
    ``GenericModel``) likely leaves TypeVar-typed fields unvalidated against
    the concrete SDK types; confirm if validation is expected.
    """

    # Parts forming the system instruction; None when there is none.
    system_instruction: List[PartT] | None = None
    # Conversation turns, one SDK content object per message.
    contents: List[ContentT]

    class Config:
        # Presumably lets non-pydantic SDK classes be used as field types.
        arbitrary_types_allowed = True
class GeminiConversation(GeminiConversationBase[Part, Content]):
    """Conversation built from Vertex AI SDK ``Part``/``Content`` objects."""

    pass
class ConversationFactory(
ConversationFactoryBase[Part, Content, GeminiConversation]
):
@staticmethod
def _to_gemini_role(role: Role) -> str:
match role:
case Role.SYSTEM:
raise ValidationError(
"System messages other than the first system message are not allowed"
)
case Role.USER | Role.FUNCTION | Role.TOOL:
return ChatSession._USER_ROLE
case Role.ASSISTANT:
return ChatSession._MODEL_ROLE
case _:
assert_never(role)
def create_multi_modal_part(self, data: bytes, mime_type: str) -> Part:
return Part.from_data(data=data, mime_type=mime_type)
def create_text_part(self, text: str) -> Part:
return Part.from_text(text)
def create_function_call_part(self, name: str, args: str) -> Part:
try:
args = json.loads(args)
return Part.from_dict(
{"function_call": {"name": name, "args": args}}
)
except Exception:
raise ValidationError(
"Function call arguments must be a valid JSON"
)
def create_function_result_part(self, name: str, args: str) -> Part:
try:
args = json.loads(args)
except Exception:
args = args
if isinstance(args, dict):
return Part.from_function_response(name, args)
return Part.from_function_response(name, {"content": args})
def create_content(self, role: Role, parts: List[Part]) -> Content:
return Content(role=self._to_gemini_role(role), parts=parts)
def create_conversation(
self, system_instruction: List[Part] | None, contents: List[Content]
) -> GeminiConversation:
return GeminiConversation(
system_instruction=system_instruction, contents=contents
)
class GeminiGenAIConversation(GeminiConversationBase[GenAIPart, GenAIContent]):
    """Conversation built from google-genai ``Part``/``Content`` objects."""

    pass
class GenAIConversationFactory(
ConversationFactoryBase[GenAIPart, GenAIContent, GeminiGenAIConversation]
):
@staticmethod
def to_gemini_genai_role(role: Role) -> str:
match role:
case Role.SYSTEM:
raise ValidationError(
"System messages other than the first system message are not allowed"
)
case Role.USER | Role.FUNCTION | Role.TOOL:
return "user"
case Role.ASSISTANT:
return "model"
case _:
assert_never(role)
def create_multi_modal_part(self, data: bytes, mime_type: str) -> GenAIPart:
return GenAIPart.from_bytes(data=data, mime_type=mime_type)
def create_text_part(self, text: str) -> GenAIPart:
return GenAIPart.from_text(text)
def create_function_call_part(self, name: str, args: str) -> GenAIPart:
try:
return GenAIPart.from_function_call(name, json.loads(args))
except Exception:
raise ValidationError(
"Function call arguments must be a valid JSON"
)
def create_function_result_part(self, name: str, args: str) -> GenAIPart:
try:
processed_args = json.loads(args)
except Exception:
processed_args = args
if isinstance(processed_args, dict):
return GenAIPart.from_function_response(name, processed_args)
return GenAIPart.from_function_response(
name, {"output": processed_args}
)
def create_content(
self, role: Role, parts: List[GenAIPart]
) -> GenAIContent:
return GenAIContent(role=self.to_gemini_genai_role(role), parts=parts)
def create_conversation(
self,
system_instruction: List[GenAIPart] | None,
contents: List[GenAIContent],
) -> GeminiGenAIConversation:
return GeminiGenAIConversation(
system_instruction=system_instruction, contents=contents
)