# aidial_adapter_vertexai/chat/gemini/adapter/vertex_lib.py

from logging import DEBUG
from typing import AsyncIterator, Callable, List, Optional, assert_never, cast

from aidial_sdk.chat_completion import FinishReason, Message
from typing_extensions import override
from vertexai.preview.generative_models import (
Candidate,
GenerationResponse,
GenerativeModel,
Image,
Part,
)

from aidial_adapter_vertexai.chat.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.gemini.error import generate_with_retries
from aidial_adapter_vertexai.chat.gemini.finish_reason import (
to_openai_finish_reason,
)
from aidial_adapter_vertexai.chat.gemini.generation_config import (
create_generation_config,
)
from aidial_adapter_vertexai.chat.gemini.grounding import create_grounding
from aidial_adapter_vertexai.chat.gemini.output import (
create_attachments_from_citations,
create_function_calls,
set_usage,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiPrompt
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_0_pro import (
Gemini_1_0_Pro_Prompt,
)
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_0_pro_vision import (
Gemini_1_0_Pro_Vision_Prompt,
)
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_1_5 import (
Gemini_1_5_Prompt,
)
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
ChatCompletionDeployment,
GeminiDeployment,
)
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.storage import FileStorage
from aidial_adapter_vertexai.utils.json import json_dumps, json_dumps_short
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.utils.timer import Timer


def _get_candidate_text_safe(candidate: Candidate) -> str | None:
    # The text content of a candidate may be missing when a function is called
    # or when the generation was terminated with the SAFETY finish reason.
try:
return candidate.text
except ValueError as e:
log.debug(f"The Candidate doesn't have text: {e}")
        return None


class GeminiChatCompletionAdapter(ChatCompletionAdapter[GeminiPrompt]):
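    # Adapter that serves DIAL chat completions with Gemini models on Vertex AI.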
    deployment: GeminiDeployment

    def __init__(
self,
file_storage: Optional[FileStorage],
model_id: str,
deployment: GeminiDeployment,
):
self.file_storage = file_storage
self.model_id = model_id
        self.deployment = deployment

    @override
async def parse_prompt(
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiPrompt | UserError:
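        # Pick the prompt parser matching the deployment; the Vision and 1.5
        # parsers additionally receive the file storage for multimodal inputs.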
match self.deployment:
case ChatCompletionDeployment.GEMINI_PRO_1:
return await Gemini_1_0_Pro_Prompt.parse(
tools, static_tools, messages
)
case ChatCompletionDeployment.GEMINI_PRO_VISION_1:
return await Gemini_1_0_Pro_Vision_Prompt.parse(
self.file_storage, tools, static_tools, messages
)
case (
ChatCompletionDeployment.GEMINI_PRO_1_5_PREVIEW
| ChatCompletionDeployment.GEMINI_PRO_1_5_V1
| ChatCompletionDeployment.GEMINI_PRO_1_5_V2
| ChatCompletionDeployment.GEMINI_FLASH_1_5_V1
| ChatCompletionDeployment.GEMINI_FLASH_1_5_V2
):
return await Gemini_1_5_Prompt.parse(
self.file_storage, tools, static_tools, messages
)
case _:
                assert_never(self.deployment)

    def _get_model(
self,
*,
params: ModelParameters | None = None,
prompt: GeminiPrompt | None = None,
) -> GenerativeModel:
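        # Assemble a GenerativeModel with the generation config and, when a
        # prompt is provided, its tools, tool config and system instruction.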
parameters = create_generation_config(params) if params else None
if prompt is not None:
tools = prompt.to_gemini_tools() or None
tool_config = prompt.tools.to_gemini_tool_config()
system_instruction = cast(
List[str | Part | Image] | None,
prompt.system_instruction,
)
else:
tools = None
tool_config = None
system_instruction = None
return GenerativeModel(
self.model_id,
generation_config=parameters,
tools=tools,
tool_config=tool_config,
system_instruction=system_instruction,
        )

    async def send_message_async(
self, params: ModelParameters, prompt: GeminiPrompt
) -> AsyncIterator[GenerationResponse]:
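        # In streaming mode yield response chunks as they arrive; otherwise
        # yield a single complete response.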
model = self._get_model(params=params, prompt=prompt)
contents = prompt.contents
if params.stream:
response = await model._generate_content_streaming_async(contents)
async for chunk in response:
yield chunk
else:
            yield await model._generate_content_async(contents)

    async def process_chunks(
self,
consumer: Consumer,
tools: ToolsConfig,
generator: Callable[[], AsyncIterator[GenerationResponse]],
) -> AsyncIterator[str]:
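        # Forward each chunk to the consumer: text content, function calls,
        # grounding, citation attachments and the finish reason. Usage metadata
        # is reported once the whole stream has been consumed.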
usage_metadata = None
is_grounding_added = False
async for chunk in generator():
if log.isEnabledFor(DEBUG):
chunk_str = json_dumps(chunk, excluded_keys=["safety_ratings"])
log.debug(f"response chunk: {chunk_str}")
if chunk.candidates:
candidate = chunk.candidates[0]
if (content := _get_candidate_text_safe(candidate)) is not None:
await consumer.append_content(content)
yield content
await create_function_calls(candidate, consumer, tools)
is_grounding_added |= await create_grounding(
candidate, consumer
)
await create_attachments_from_citations(candidate, consumer)
if openai_reason := to_openai_finish_reason(
candidate.finish_reason,
consumer.is_empty(),
):
await consumer.set_finish_reason(openai_reason)
if chunk.usage_metadata:
usage_metadata = chunk.usage_metadata
if chunk.prompt_feedback:
await consumer.set_finish_reason(FinishReason.CONTENT_FILTER)
if usage_metadata:
await set_usage(
usage_metadata,
consumer,
self.deployment,
is_grounding_added,
            )

    @override
async def chat(
self, params: ModelParameters, consumer: Consumer, prompt: GeminiPrompt
) -> None:
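        # Generate the completion with retries, streaming content into the
        # consumer and accumulating it for debug logging.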
with Timer("predict timing: {time}", log.debug):
if log.isEnabledFor(DEBUG):
log.debug(
"predict request: "
+ json_dumps_short({"parameters": params, "prompt": prompt})
)
completion = ""
async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
prompt.tools,
lambda: self.send_message_async(params, prompt),
),
2,
):
completion += content
            log.debug(f"predict response: {completion!r}")

    @override
async def truncate_prompt(
self, prompt: GeminiPrompt, max_prompt_tokens: int
) -> TruncatedPrompt[GeminiPrompt]:
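        # Delegate truncation to the prompt itself, counting tokens with the
        # model-backed tokenizer.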
return await prompt.truncate(
tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
        )

    @override
async def count_prompt_tokens(self, prompt: GeminiPrompt) -> int:
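        # Count prompt tokens by calling the model's count_tokens_async on the
        # prompt contents.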
with Timer("count_tokens[prompt] timing: {time}", log.debug):
resp = await self._get_model(prompt=prompt).count_tokens_async(
prompt.contents
)
log.debug(f"count_tokens[prompt] response: {json_dumps(resp)}")
            return resp.total_tokens

    @override
async def count_completion_tokens(self, string: str) -> int:
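        # The completion string is tokenized with the same model to count
        # completion tokens.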
with Timer("count_tokens[completion] timing: {time}", log.debug):
resp = await self._get_model().count_tokens_async(string)
log.debug(f"count_tokens[completion] response: {json_dumps(resp)}")
            return resp.total_tokens

    @classmethod
async def create(
cls,
file_storage: Optional[FileStorage],
model_id: str,
deployment: GeminiDeployment,
) -> "GeminiChatCompletionAdapter":
return cls(file_storage, model_id, deployment)