aidial_adapter_bedrock/chat_completion.py (155 lines of code) (raw):
import asyncio
from typing import List, Optional, assert_never
from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
TokenizeError,
TokenizeInputRequest,
TokenizeInputString,
TokenizeOutput,
TokenizeRequest,
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptError,
TruncatePromptRequest,
TruncatePromptResponse,
TruncatePromptResult,
TruncatePromptSuccess,
)
from typing_extensions import override
from aidial_adapter_bedrock.adapter_deployments import (
AdapterChatCompletionDeployment,
)
from aidial_adapter_bedrock.aws_client_config import AWSClientConfigFactory
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import (
ChatCompletionAdapter,
TextCompletionAdapter,
)
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.server.exceptions import (
dial_exception_decorator,
not_implemented_handler,
)
from aidial_adapter_bedrock.utils.log_config import app_logger as log
class BedrockChatCompletion(ChatCompletion):
deployment: AdapterChatCompletionDeployment
def __init__(self, deployment: AdapterChatCompletionDeployment) -> None:
self.deployment = deployment
async def _get_model(
self, request: FromRequestDeploymentMixin
) -> ChatCompletionAdapter:
aws_client_config = await AWSClientConfigFactory(
request=request,
).get_client_config()
return await get_bedrock_adapter(
deployment=self.deployment,
api_key=request.api_key,
aws_client_config=aws_client_config,
)
@dial_exception_decorator
async def chat_completion(self, request: Request, response: Response):
response.set_model(self.deployment.upstream_deployment_id)
model = await self._get_model(request)
params = ModelParameters.create(request)
discarded_messages: Optional[DiscardedMessages] = None
async def generate_response(usage: TokenUsage) -> None:
nonlocal discarded_messages
with ChoiceConsumer(response=response) as consumer:
if isinstance(model, TextCompletionAdapter):
consumer.set_tools_emulator(
model.tools_emulator(params.tool_config)
)
try:
await model.chat(consumer, params, request.messages)
except UserError as e:
await e.report_usage(consumer.choice)
await response.aflush()
raise e
usage.accumulate(consumer.usage)
discarded_messages = consumer.discarded_messages
usage = TokenUsage()
await asyncio.gather(
*(generate_response(usage) for _ in range(request.n or 1))
)
log.debug(f"usage: {usage}")
response.set_usage(usage.prompt_tokens, usage.completion_tokens)
if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)
@override
@dial_exception_decorator
@not_implemented_handler
async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
model = await self._get_model(request)
outputs: List[TokenizeOutput] = []
for input in request.inputs:
match input:
case TokenizeInputRequest():
outputs.append(
await self._tokenize_request(model, input.value)
)
case TokenizeInputString():
outputs.append(
await self._tokenize_string(model, input.value)
)
case _:
assert_never(input.type)
return TokenizeResponse(outputs=outputs)
async def _tokenize_string(
self, model: ChatCompletionAdapter, value: str
) -> TokenizeOutput:
try:
tokens = await model.count_completion_tokens(value)
return TokenizeSuccess(token_count=tokens)
except NotImplementedError:
raise
except Exception as e:
return TokenizeError(error=str(e))
async def _tokenize_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TokenizeOutput:
params = ModelParameters.create(request)
try:
token_count = await model.count_prompt_tokens(
params, request.messages
)
return TokenizeSuccess(token_count=token_count)
except NotImplementedError:
raise
except Exception as e:
return TokenizeError(error=str(e))
@override
@dial_exception_decorator
@not_implemented_handler
async def truncate_prompt(
self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
model = await self._get_model(request)
outputs: List[TruncatePromptResult] = []
for input in request.inputs:
outputs.append(await self._truncate_prompt_request(model, input))
return TruncatePromptResponse(outputs=outputs)
async def _truncate_prompt_request(
self, model: ChatCompletionAdapter, request: ChatCompletionRequest
) -> TruncatePromptResult:
try:
params = ModelParameters.create(request)
if params.max_prompt_tokens is None:
raise ValidationError("max_prompt_tokens is required")
discarded_messages = await model.compute_discarded_messages(
params, request.messages
)
return TruncatePromptSuccess(
discarded_messages=discarded_messages or []
)
except NotImplementedError:
raise
except Exception as e:
return TruncatePromptError(error=str(e))