aidial_adapter_vertexai/chat/bison/base.py:

from abc import abstractmethod
from typing import AsyncIterator, List

from aidial_sdk.chat_completion import FinishReason, Message
from typing_extensions import override
from vertexai.preview.language_models import (
    ChatModel,
    CodeChatModel,
    CountTokensResponse,
)

from aidial_adapter_vertexai.chat.bison.prompt import BisonPrompt
from aidial_adapter_vertexai.chat.chat_completion_adapter import (
    ChatCompletionAdapter,
)
from aidial_adapter_vertexai.chat.consumer import Consumer
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.chat.truncate_prompt import TruncatedPrompt
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import vertex_ai_logger as log
from aidial_adapter_vertexai.utils.timer import Timer
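
# Both Bison chat model classes expose the `start_chat` interface used below.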
BisonChatModel = ChatModel | CodeChatModel


class BisonChatCompletionAdapter(ChatCompletionAdapter[BisonPrompt]):
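    """Chat completion adapter for the PaLM (Bison) chat models.

    Concrete subclasses only have to implement `send_message_async`
    for a specific model family (chat or code chat).
    """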

    def __init__(self, model: BisonChatModel):
        self.model = model

    @abstractmethod
    def send_message_async(
        self, params: ModelParameters, prompt: BisonPrompt
    ) -> AsyncIterator[str]:
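        """Send the prompt to the model and yield completion chunks."""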
        pass

    @override
    async def parse_prompt(
        self,
        tools: ToolsConfig,
        static_tools: StaticToolsConfig,
        messages: List[Message],
    ) -> BisonPrompt:
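        """Parse DIAL messages into a Bison prompt; tools are rejected."""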
        tools.not_supported()
        static_tools.not_supported()
        return BisonPrompt.parse(messages)

    @override
    async def truncate_prompt(
        self, prompt: BisonPrompt, max_prompt_tokens: int
    ) -> TruncatedPrompt[BisonPrompt]:
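        """Truncate the prompt to fit into the user-provided token limit."""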
        return await prompt.truncate(
            tokenizer=self.count_prompt_tokens, user_limit=max_prompt_tokens
        )

    @override
    async def chat(
        self, params: ModelParameters, consumer: Consumer, prompt: BisonPrompt
    ) -> None:
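        """Stream the completion to the consumer and report token usage."""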
        prompt_tokens = await self.count_prompt_tokens(prompt)

        with Timer("predict timing: {time}", log.debug):
            log.debug(
                "predict request: "
                f"parameters=({params}), "
                f"prompt=({prompt})"
            )

            completion = ""
            async for chunk in self.send_message_async(params, prompt):
                completion += chunk
                await consumer.append_content(chunk)

            log.debug(f"predict response: {completion!r}")

        completion_tokens = await self.count_completion_tokens(completion)

        # PaLM models do not return a finish reason,
        # so use a heuristic to estimate it.
        if completion_tokens == params.max_tokens:
            await consumer.set_finish_reason(FinishReason.LENGTH)

        await consumer.set_usage(
            TokenUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
            )
        )

    @override
    async def count_prompt_tokens(self, prompt: BisonPrompt) -> int:
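        """Count prompt tokens using the model's `count_tokens` call."""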
        chat_session = self.model.start_chat(
            context=prompt.system_instruction,
            message_history=prompt.history,
        )

        with Timer("count_tokens[prompt] timing: {time}", log.debug):
            resp = chat_session.count_tokens(message=prompt.last_user_message)
            log.debug(
                f"count_tokens[prompt] response: {_display_token_count(resp)}"
            )
            return resp.total_tokens

    @override
    async def count_completion_tokens(self, string: str) -> int:
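        """Count completion tokens via `count_tokens` on a fresh chat session."""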
        with Timer("count_tokens[completion] timing: {time}", log.debug):
            resp = self.model.start_chat().count_tokens(message=string)
            log.debug(
                f"count_tokens[completion] response: {_display_token_count(resp)}"
            )
            return resp.total_tokens


def _display_token_count(response: CountTokensResponse) -> str:
    return (
        f"tokens: {response.total_tokens}, "
        f"billable characters: {response.total_billable_characters}"
    )