aidial_adapter_vertexai/chat_completion.py
import asyncio
from typing import List, assert_never

from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
TokenizeError,
TokenizeInputRequest,
TokenizeInputString,
TokenizeOutput,
TokenizeRequest,
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptError,
TruncatePromptRequest,
TruncatePromptResponse,
TruncatePromptResult,
TruncatePromptSuccess,
)
from aidial_sdk.exceptions import ResourceNotFoundError
from google.genai.client import Client as GenAIClient
from typing_extensions import override

from aidial_adapter_vertexai.adapters import get_chat_completion_model
from aidial_adapter_vertexai.chat.chat_completion_adapter import (
ChatCompletionAdapter,
TruncatedPrompt,
)
from aidial_adapter_vertexai.chat.consumer import ChoiceConsumer
from aidial_adapter_vertexai.chat.errors import UserError, ValidationError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from aidial_adapter_vertexai.dial_api.exceptions import dial_exception_decorator
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import app_logger as log
from aidial_adapter_vertexai.utils.not_implemented import is_implemented


class VertexAIChatCompletion(ChatCompletion):
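    """DIAL ChatCompletion endpoint backed by Vertex AI models through the Google GenAI client."""
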
def __init__(self, client: GenAIClient):
self.client = client

    async def _get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
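        """Resolve the request's deployment_id into a concrete chat model adapter."""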
return await get_chat_completion_model(
deployment=ChatCompletionDeployment(request.deployment_id),
api_key=request.api_key,
client=self.client,
)

    @dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
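        """Handle a chat completion request and stream the result into the DIAL response."""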
response.set_model(request.deployment_id)
model = await self._get_model(request)
tools = ToolsConfig.from_request(request)
static_tools = StaticToolsConfig.from_request(request)
prompt = await model.parse_prompt(tools, static_tools, request.messages)
if isinstance(prompt, UserError):
await prompt.report_usage(response)
raise prompt
params = ModelParameters.create(request)
# Currently n>1 is emulated by calling the model n times
n = params.n or 1
params.n = None
if params.max_prompt_tokens is None:
truncated_prompt = TruncatedPrompt(
prompt=prompt, discarded_messages=[]
)
else:
if not is_implemented(model.truncate_prompt):
raise ValidationError(
"max_prompt_tokens request parameter is not supported"
)
truncated_prompt = await model.truncate_prompt(
prompt, params.max_prompt_tokens
)
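
        # Each emulated choice comes from a separate model call; every call
        # accumulates its token counts into the shared TokenUsage instance.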
async def generate_response(usage: TokenUsage) -> None:
with ChoiceConsumer(response=response) as consumer:
await model.chat(params, consumer, truncated_prompt.prompt)
usage.accumulate(consumer.usage)
log.debug(
f"finish_reason[{consumer.choice_idx}]: {consumer.finish_reason}"
)

        usage = TokenUsage()
await asyncio.gather(*(generate_response(usage) for _ in range(n)))
log.debug(f"usage: {usage}")
response.set_usage(usage.prompt_tokens, usage.completion_tokens)
if params.max_prompt_tokens is not None:
response.set_discarded_messages(truncated_prompt.discarded_messages)

    @override
@dial_exception_decorator
async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
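        """Count tokens for each input, which may be a raw string or a full chat request."""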
model = await self._get_model(request)
if not is_implemented(
model.count_completion_tokens
) or not is_implemented(model.count_prompt_tokens):
raise ResourceNotFoundError("The endpoint is not implemented")
outputs: List[TokenizeOutput] = []
for input in request.inputs:
match input:
case TokenizeInputRequest():
outputs.append(
await self._tokenize_request(model, input.value)
)
case TokenizeInputString():
outputs.append(
await self._tokenize_string(model, input.value)
)
                case _:
                    assert_never(input)
return TokenizeResponse(outputs=outputs)

    async def _tokenize_string(
self, model: ChatCompletionAdapter, value: str
) -> TokenizeOutput:
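        """Count tokens in a raw string (the adapter treats it as completion text)."""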
try:
tokens = await model.count_completion_tokens(value)
return TokenizeSuccess(token_count=tokens)
except Exception as e:
return TokenizeError(error=str(e))

    async def _tokenize_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TokenizeOutput:
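        """Parse the request into a prompt and count its prompt tokens."""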
try:
tools = ToolsConfig.from_request(request)
static_tools = StaticToolsConfig.from_request(request)
prompt = await model.parse_prompt(
tools, static_tools, request.messages
)
if isinstance(prompt, UserError):
raise prompt
token_count = await model.count_prompt_tokens(prompt)
return TokenizeSuccess(token_count=token_count)
except Exception as e:
return TokenizeError(error=str(e))

    @override
@dial_exception_decorator
async def truncate_prompt(
self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
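        """Report, per input, which messages would be discarded to satisfy max_prompt_tokens."""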
model = await self._get_model(request)
if not is_implemented(model.truncate_prompt):
raise ResourceNotFoundError("The endpoint is not implemented")
outputs: List[TruncatePromptResult] = []
for input in request.inputs:
outputs.append(await self._truncate_prompt_request(model, input))
return TruncatePromptResponse(outputs=outputs)

    async def _truncate_prompt_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TruncatePromptResult:
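        """Truncate a single request's prompt and report its discarded messages."""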
try:
if request.max_prompt_tokens is None:
raise ValidationError("max_prompt_tokens is required")
tools = ToolsConfig.from_request(request)
static_tools = StaticToolsConfig.from_request(request)
            prompt = await model.parse_prompt(
                tools, static_tools, request.messages
            )
            # parse_prompt returns a UserError value instead of raising it,
            # so surface it here as the other endpoints do.
            if isinstance(prompt, UserError):
                raise prompt
truncated_prompt = await model.truncate_prompt(
prompt, request.max_prompt_tokens
)
return TruncatePromptSuccess(
discarded_messages=truncated_prompt.discarded_messages
)
except Exception as e:
return TruncatePromptError(error=str(e))
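

# Usage sketch (illustrative, not taken from this module): the class is meant
# to be registered on a DIALApp with a google-genai Client configured for
# Vertex AI. The deployment name must be one of ChatCompletionDeployment's
# values; the name, project, and location below are placeholders.
#
#   from aidial_sdk import DIALApp
#   from google.genai.client import Client as GenAIClient
#
#   app = DIALApp()
#   app.add_chat_completion(
#       "my-gemini-deployment",
#       VertexAIChatCompletion(
#           GenAIClient(vertexai=True, project="my-project", location="us-central1")
#       ),
#   )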