# aidial_adapter_bedrock/bedrock.py
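"""Asynchronous wrappers around the synchronous boto3 ``bedrock-runtime`` client.

Blocking boto3 calls are offloaded to worker threads via ``make_async``, and
streaming responses are exposed as async iterators via ``to_async_iterator``,
so that no call stalls the event loop.
"""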
import json
from abc import ABC
from logging import DEBUG
from typing import Any, AsyncIterator, Mapping, Optional, Tuple, Unpack

import boto3
from botocore.eventstream import EventStream
from botocore.response import StreamingBody
from pydantic import BaseModel, Field

from aidial_adapter_bedrock.aws_client_config import AWSClientConfig
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.converse.types import ConverseRequest
from aidial_adapter_bedrock.utils.concurrency import (
make_async,
to_async_iterator,
)
from aidial_adapter_bedrock.utils.json import json_dumps_short
from aidial_adapter_bedrock.utils.log_config import bedrock_logger as log

Body = dict
Headers = Mapping[str, str]


class Bedrock:
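    """Thin async facade over a synchronous boto3 ``bedrock-runtime`` client."""
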
    client: Any

    def __init__(self, client: Any):
        self.client = client

    @classmethod
async def acreate(cls, aws_client_config: AWSClientConfig) -> "Bedrock":
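        """Create a ``bedrock-runtime`` client in a worker thread, since
        boto3 client construction performs blocking I/O."""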
client_kwargs = aws_client_config.get_boto_client_kwargs()
client_kwargs["service_name"] = "bedrock-runtime"
client = await make_async(
lambda: boto3.Session().client(**client_kwargs)
)
        return cls(client)

    async def aconverse_non_streaming(
self, model: str, **params: Unpack[ConverseRequest]
):
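        """Invoke the Converse API and return the full response."""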
        return await make_async(
            lambda: self.client.converse(modelId=model, **params)
        )

    async def aconverse_streaming(
self, model: str, **params: Unpack[ConverseRequest]
):
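        """Invoke the ConverseStream API and return an async iterator
        over the raw stream events."""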
response = await make_async(
lambda: self.client.converse_stream(modelId=model, **params)
)
        return to_async_iterator(iter(response["stream"]))

    def _create_invoke_params(self, model: str, body: dict) -> dict:
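        """Build the common InvokeModel parameters: the body is JSON-encoded
        and both request and response content types are pinned to JSON."""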
return {
"modelId": model,
"body": json.dumps(body),
"accept": "application/json",
"contentType": "application/json",
        }

    async def ainvoke_non_streaming(
self, model: str, args: dict
) -> Tuple[Body, Headers]:
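        """Invoke the model and return the parsed response body together
        with the HTTP response headers."""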
if log.isEnabledFor(DEBUG):
log.debug(
f"request: {json_dumps_short({'model': model, 'args': args})}"
)
params = self._create_invoke_params(model, args)
response = await make_async(lambda: self.client.invoke_model(**params))
if log.isEnabledFor(DEBUG):
log.debug(f"response: {json_dumps_short(response)}")
body: StreamingBody = response["body"]
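        # ``StreamingBody.read()`` blocks, so run it in a worker thread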
body_dict = json.loads(await make_async(lambda: body.read()))
response_headers = response.get("ResponseMetadata", {}).get(
"HTTPHeaders", {}
)
if log.isEnabledFor(DEBUG):
log.debug(f"response['body']: {json_dumps_short(body_dict)}")
        return body_dict, response_headers

    async def ainvoke_streaming(
self, model: str, args: dict
) -> AsyncIterator[dict]:
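        """Invoke the model with a response stream and yield each chunk
        as a parsed dict."""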
if log.isEnabledFor(DEBUG):
log.debug(
f"request: {json_dumps_short({'model': model, 'args': args})}"
)
params = self._create_invoke_params(model, args)
response = await make_async(
lambda: self.client.invoke_model_with_response_stream(**params)
)
if log.isEnabledFor(DEBUG):
log.debug(f"response: {json_dumps_short(response)}")
body: EventStream = response["body"]
async for event in to_async_iterator(iter(body)):
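            # Only "chunk" events carry a payload; other event types
            # (e.g. stream errors) are ignored here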
chunk = event.get("chunk")
if chunk:
chunk_dict = json.loads(chunk.get("bytes").decode())
if log.isEnabledFor(DEBUG):
log.debug(f"chunk: {json_dumps_short(chunk_dict)}")
            yield chunk_dict


class InvocationMetrics(BaseModel):
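    """Invocation metrics reported by Bedrock in the
    ``amazon-bedrock-invocationMetrics`` response field; latencies are in
    milliseconds."""
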
inputTokenCount: int
outputTokenCount: int
invocationLatency: int
    firstByteLatency: int


class ResponseWithInvocationMetricsMixin(ABC, BaseModel):
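    """Mixin for response models that carry the optional
    ``amazon-bedrock-invocationMetrics`` field."""
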
invocation_metrics: Optional[InvocationMetrics] = Field(
alias="amazon-bedrock-invocationMetrics"
    )

    def usage_by_metrics(self) -> TokenUsage:
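        """Map the invocation metrics onto ``TokenUsage``; an empty usage
        is returned when no metrics are present."""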
metrics = self.invocation_metrics
if metrics is None:
return TokenUsage()
return TokenUsage(
prompt_tokens=metrics.inputTokenCount,
completion_tokens=metrics.outputTokenCount,
)
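

# Usage sketch (illustrative, not part of the adapter): assumes AWS
# credentials are available and ``aws_client_config`` is an AWSClientConfig;
# ``model_id``, ``request_body`` and ``MyResponse`` (a model mixing in
# ResponseWithInvocationMetricsMixin) are hypothetical placeholders.
#
#   bedrock = await Bedrock.acreate(aws_client_config)
#
#   # Non-streaming InvokeModel call:
#   body, headers = await bedrock.ainvoke_non_streaming(model_id, request_body)
#   usage = MyResponse.parse_obj(body).usage_by_metrics()
#
#   # Streaming InvokeModel call:
#   async for chunk in bedrock.ainvoke_streaming(model_id, request_body):
#       ...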