aidial_adapter_vertexai/chat/gemini/prompt/base.py (67 lines of code) (raw):

from abc import ABC
from typing import Generic, List, Self, Set

from google.genai.types import Content as GenAIContent
from google.genai.types import Part as GenAIPart
from pydantic.v1 import BaseModel, Field
from vertexai.preview.generative_models import Content, Part
from vertexai.preview.generative_models import Tool as GeminiTool

from aidial_adapter_vertexai.chat.gemini.conversation_factory import (
    ContentT,
    PartT,
)
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatablePrompt


class GeminiBasePrompt(
    BaseModel, TruncatablePrompt, ABC, Generic[PartT, ContentT]
):
    """A Gemini prompt: optional system instruction, contents and tool configs.

    For truncation purposes the prompt is viewed as a flat sequence of
    messages numbered from 0: when a system instruction is present it is
    message 0 and the contents follow it in order; otherwise the contents
    start at 0. ``__len__``, ``is_required_message``, ``partition_messages``
    and ``select`` all use this numbering (presumably they are the hooks
    driven by ``TruncatablePrompt`` — confirm against that base class).
    """

    # System instruction parts; ``None`` means "no system instruction".
    system_instruction: List[PartT] | None = None
    # Conversation contents, in order.
    contents: List[ContentT]
    # Regular tool configuration (no tools by default).
    tools: ToolsConfig = Field(default_factory=ToolsConfig.noop)
    # Static tool configuration (no static tools by default).
    static_tools: StaticToolsConfig = Field(
        default_factory=StaticToolsConfig.noop
    )

    class Config:
        # Part/Content types come from the Gemini SDKs, not pydantic models.
        arbitrary_types_allowed = True

    @property
    def has_system_instruction(self) -> bool:
        """Whether a system instruction is set (an empty list still counts)."""
        return self.system_instruction is not None

    def is_required_message(self, index: int) -> bool:
        """Return True if message ``index`` must never be truncated away.

        The system instruction (index 0, when present) and the last message
        are always kept.
        """
        # Keep the system message...
        if self.has_system_instruction and index == 0:
            return True

        # ...and the last message (presumably the latest user prompt)
        if index == len(self) - 1:
            return True

        return False

    def __len__(self) -> int:
        """Number of messages: the system instruction (if any) plus contents."""
        return int(self.has_system_instruction) + len(self.contents)

    def partition_messages(self) -> List[int]:
        """Return the sizes of consecutive message groups, in order.

        The system instruction (if any) forms its own group. The contents are
        grouped in pairs (presumably user/model exchanges, so whole turns are
        dropped together), and an odd trailing content forms a final
        singleton group. The sizes always sum to ``len(self)``.
        """
        n = len(self.contents)
        return (
            [1] * self.has_system_instruction
            + [2] * (n // 2)
            + [1] * (n % 2)
        )

    def select(self, indices: Set[int]) -> Self:
        """Return a new prompt keeping only the messages at ``indices``.

        Indices use the flat numbering described on the class: 0 is the
        system instruction when present, and contents follow it. Tool
        configurations are carried over unchanged.

        Raises:
            RuntimeError: if the prompt has contents but the last message
                is not among ``indices``.
        """
        # Contents are numbered after the system instruction whenever one
        # exists, whether or not the system instruction itself is selected.
        # (Tying the offset to "0 in indices" would misalign every content
        # index when the system instruction is dropped.)
        offset = int(self.has_system_instruction)

        system_instruction: List[PartT] | None = None
        if self.has_system_instruction and 0 in indices:
            system_instruction = self.system_instruction

        contents: List[ContentT] = [
            content
            for idx, content in enumerate(self.contents)
            if idx + offset in indices
        ]

        # len(self) - 1 is the index of the last content when contents exist.
        if self.contents and len(self) - 1 not in indices:
            raise RuntimeError("The last user prompt must not be omitted.")

        return self.__class__(
            system_instruction=system_instruction,
            contents=contents,
            tools=self.tools,
            static_tools=self.static_tools,
        )

    def to_gemini_tools(self) -> List[GeminiTool]:
        """Return the regular tools followed by the static tools."""
        regular_tools = self.tools.to_gemini_tools()
        static_tools = self.static_tools.to_gemini_tools()
        return regular_tools + static_tools


class GeminiPrompt(GeminiBasePrompt[Part, Content]):
    """Prompt built from ``vertexai`` SDK parts and contents."""


class GeminiGenAIPrompt(GeminiBasePrompt[GenAIPart, GenAIContent]):
    """Prompt built from ``google.genai`` SDK parts and contents."""