# aidial_adapter_vertexai/chat/gemini/adapter/genai_lib.py
from logging import DEBUG
from typing import AsyncIterator, Callable, List, Optional, assert_never

from aidial_sdk.chat_completion import FinishReason, Message, Stage
from aidial_sdk.exceptions import RuntimeServerError
from google.genai.client import Client as GenAIClient
from google.genai.types import (
GenerateContentResponse as GenAIGenerateContentResponse,
)
from typing_extensions import override

from aidial_adapter_vertexai.chat.chat_completion_adapter import (
ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.errors import UserError
from aidial_adapter_vertexai.chat.gemini.error import generate_with_retries
from aidial_adapter_vertexai.chat.gemini.finish_reason import (
genai_to_openai_finish_reason,
)
from aidial_adapter_vertexai.chat.gemini.generation_config import (
create_genai_generation_config,
)
from aidial_adapter_vertexai.chat.gemini.grounding import create_grounding
from aidial_adapter_vertexai.chat.gemini.output import (
create_attachments_from_citations,
create_function_calls_from_genai,
set_usage,
)
from aidial_adapter_vertexai.chat.gemini.prompt.base import GeminiGenAIPrompt
from aidial_adapter_vertexai.chat.gemini.prompt.gemini_2 import Gemini_2_Prompt
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.deployments import (
ChatCompletionDeployment,
Gemini2Deployment,
)
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.storage import FileStorage
from aidial_adapter_vertexai.utils.json import json_dumps, json_dumps_short
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.utils.timer import Timer


class GeminiGenAIChatCompletionAdapter(
ChatCompletionAdapter[GeminiGenAIPrompt]
):
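    """Gemini chat completion adapter backed by the `google-genai` client.

    Serves the Gemini 2.x deployments and maps DIAL requests onto the
    google.genai streaming and non-streaming generate_content APIs.
    """
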
deployment: Gemini2Deployment

    def __init__(
self,
client: GenAIClient,
file_storage: Optional[FileStorage],
model_id: str,
deployment: Gemini2Deployment,
):
self.file_storage = file_storage
self.model_id = model_id
self.deployment = deployment
self.client = client

    @override
async def parse_prompt(
self,
tools: ToolsConfig,
static_tools: StaticToolsConfig,
messages: List[Message],
) -> GeminiGenAIPrompt | UserError:
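        """Parse DIAL messages and tool configs into a Gemini prompt.

        Returns a UserError when the messages cannot be converted into a
        valid prompt for the current deployment.
        """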
match self.deployment:
case (
ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206
| ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP
| ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219
):
return await Gemini_2_Prompt.parse(
self.file_storage, tools, static_tools, messages
)
case _:
assert_never(self.deployment)

    async def send_message_async(
self, params: ModelParameters, prompt: GeminiGenAIPrompt
) -> AsyncIterator[GenAIGenerateContentResponse]:
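        """Send the prompt to the model via the google-genai async client.

        Streams response chunks when params.stream is set; otherwise yields
        a single complete response.
        """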
generation_config = create_genai_generation_config(
params,
prompt.tools,
prompt.static_tools,
prompt.system_instruction,
)
if params.stream:
async for chunk in self.client.aio.models.generate_content_stream(
model=self.model_id,
contents=list(prompt.contents),
config=generation_config,
):
yield chunk
else:
yield await self.client.aio.models.generate_content(
model=self.model_id,
contents=list(prompt.contents),
config=generation_config,
)

    async def process_chunks(
self,
consumer: Consumer,
tools: ToolsConfig,
generator: Callable[[], AsyncIterator[GenAIGenerateContentResponse]],
):
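        """Stream google-genai response chunks into the consumer.

        Routes model thoughts to a dedicated "Thought Process" stage, emits
        function calls, grounding, citation attachments, finish reason and
        usage, and yields the generated text fragments to the caller.
        """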
thinking_stage: Stage | None = None
usage_metadata = None
is_grounding_added = False
try:
async for chunk in generator():
if log.isEnabledFor(DEBUG):
chunk_str = json_dumps(chunk)
log.debug(f"response chunk: {chunk_str}")
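                # A populated prompt_feedback means the prompt was blocked,
                # so report it as a content-filter finish reason.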
if chunk.prompt_feedback:
await consumer.set_finish_reason(
FinishReason.CONTENT_FILTER
)
if chunk.usage_metadata:
usage_metadata = chunk.usage_metadata
if not chunk.candidates:
continue
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
await create_function_calls_from_genai(
part, consumer, tools
)
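                        # Thought parts go to a separate "Thought Process"
                        # stage instead of the main content.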
if part.thought and part.text:
if thinking_stage is None:
thinking_stage = await consumer.create_stage(
"Thought Process"
)
thinking_stage.open()
thinking_stage.append_content(part.text)
yield part.text
elif part.text:
await consumer.append_content(part.text)
yield part.text
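                # Remember whether grounding was added; set_usage needs it below.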
is_grounding_added |= await create_grounding(
candidate, consumer
)
await create_attachments_from_citations(candidate, consumer)
if openai_reason := genai_to_openai_finish_reason(
candidate.finish_reason,
consumer.is_empty(),
):
await consumer.set_finish_reason(openai_reason)
finally:
if thinking_stage:
thinking_stage.close()
            # Max tokens may be reached during the thinking stage, leaving the
            # response with no regular content. set_usage would then fail with
            # the 'Trying to set "usage" before generating all choices' error,
            # so append empty content to ensure at least one choice is generated.
await consumer.append_content("")
if usage_metadata:
await set_usage(
usage_metadata,
consumer,
self.deployment,
is_grounding_added,
)

    @override
async def truncate_prompt(
self, prompt: GeminiGenAIPrompt, max_prompt_tokens: int
) -> TruncatedPrompt[GeminiGenAIPrompt]:
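        """Truncate the prompt to fit within the user-specified token limit."""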
return await prompt.truncate(
tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
)

    @override
async def count_prompt_tokens(self, prompt: GeminiGenAIPrompt) -> int:
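        """Count prompt tokens via the google-genai count_tokens endpoint."""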
with Timer("count_tokens[prompt] timing: {time}", log.debug):
resp = await self.client.aio.models.count_tokens(
model=self.model_id,
contents=list(prompt.contents),
)
log.debug(f"count_tokens[prompt] response: {json_dumps(resp)}")
if resp.total_tokens is None:
raise RuntimeServerError("Failed to count tokens for prompt")
return resp.total_tokens

    @override
async def count_completion_tokens(self, string: str) -> int:
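        """Count tokens in a completion string via the same count_tokens endpoint."""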
with Timer("count_tokens[completion] timing: {time}", log.debug):
resp = await self.client.aio.models.count_tokens(
model=self.model_id,
contents=string,
)
log.debug(f"count_tokens[completion] response: {json_dumps(resp)}")
if resp.total_tokens is None:
raise RuntimeServerError(
"Failed to count tokens for completion"
)
return resp.total_tokens

    @override
async def chat(
self,
params: ModelParameters,
consumer: Consumer,
prompt: GeminiGenAIPrompt,
) -> None:
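        """Generate a completion for the prompt and stream it into the
        consumer through generate_with_retries.
        """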
with Timer("predict timing: {time}", log.debug):
if log.isEnabledFor(DEBUG):
log.debug(
"predict request: "
+ json_dumps_short({"parameters": params, "prompt": prompt})
)
completion = ""
async for content in generate_with_retries(
lambda: self.process_chunks(
consumer,
prompt.tools,
lambda: self.send_message_async(params, prompt),
),
2,
):
completion += content
log.debug(f"predict response: {completion!r}")
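

# Illustrative usage sketch (not part of this module). The client arguments and
# the model id below are assumptions for the example, not values defined here:
#
#     client = GenAIClient(vertexai=True, project="my-project", location="us-central1")
#     adapter = GeminiGenAIChatCompletionAdapter(
#         client=client,
#         file_storage=None,
#         model_id="gemini-2.0-flash-exp",
#         deployment=ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP,
#     )
#     prompt = await adapter.parse_prompt(tools, static_tools, messages)
#     if not isinstance(prompt, UserError):
#         await adapter.chat(params, consumer, prompt)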