aidial_adapter_bedrock/aws_client_config.py (75 lines of code) (raw):

import os

import boto3
from aidial_sdk.embeddings import Request
from pydantic import BaseModel, Field

from aidial_adapter_bedrock.utils.concurrency import make_async
from aidial_adapter_bedrock.utils.env import get_aws_default_region
from aidial_adapter_bedrock.utils.json import remove_nones


class AWSClientCredentials(BaseModel):
    """Explicit AWS credentials; the session token is optional."""

    aws_access_key_id: str
    aws_secret_access_key: str
    aws_session_token: str | None = None


class AWSClientConfig(BaseModel):
    """A region plus optional credentials, convertible to client kwargs."""

    region: str
    credentials: AWSClientCredentials | None = None

    def get_boto_client_kwargs(self) -> dict:
        """Return keyword arguments for creating a boto client.

        The result always contains ``region_name``. When credentials are
        set, their non-None fields are merged in under their field names
        (``aws_access_key_id``, ``aws_secret_access_key`` and, if present,
        ``aws_session_token``).
        """
        client_kwargs = {"region_name": self.region}
        if self.credentials:
            client_kwargs.update(self.credentials.dict(exclude_none=True))
        return client_kwargs

    def get_anthropic_bedrock_client_kwargs(self) -> dict:
        """Return the same configuration under a different key naming.

        The region is exposed as ``aws_region`` and the key pair as
        ``aws_access_key`` / ``aws_secret_key``. Judging by the method
        name these are the names the Anthropic Bedrock client expects.
        """
        client_kwargs = {"aws_region": self.region}
        if self.credentials:
            # remove_nones presumably drops None-valued entries, so an
            # absent session token is omitted instead of passed as None.
            credentials = remove_nones(
                {
                    "aws_access_key": self.credentials.aws_access_key_id,
                    "aws_secret_key": self.credentials.aws_secret_access_key,
                    "aws_session_token": self.credentials.aws_session_token,
                }
            )
            client_kwargs.update(credentials)
        return client_kwargs


class UpstreamConfig(BaseModel):
    """Per-request upstream settings parsed from a request header."""

    region: str = Field(default_factory=get_aws_default_region)
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    # NOTE: the environment variable is read once, when this class body is
    # executed, not each time an instance is created.
    aws_assume_role_arn: str | None = os.environ.get("AWS_ASSUME_ROLE_ARN")


class AWSClientConfigFactory:
    """Builds an AWSClientConfig from the upstream config of a request."""

    UPSTREAM_CONFIG_HEADER_NAME = "x-upstream-extra-data"
    BEDROCK_ACCESS_SESSION_NAME = "BedrockAccessSession"

    def __init__(self, request):
        """Parse and store the upstream config carried by ``request``.

        Only ``request.headers`` is read.
        """
        self.upstream_config = self._get_upstream_config(request)

    async def get_client_config(self) -> AWSClientConfig:
        """Return the client config for the stored upstream config.

        Resolving the credentials may involve an STS call, hence async.
        """
        return AWSClientConfig(
            region=self.upstream_config.region,
            credentials=await self._get_client_credentials(),
        )

    def _get_upstream_config(self, request: Request) -> UpstreamConfig:
        """Parse the upstream config header, or fall back to defaults.

        A missing or empty header yields a default ``UpstreamConfig``.
        """
        conf = request.headers.get(self.UPSTREAM_CONFIG_HEADER_NAME)
        return UpstreamConfig.parse_raw(conf) if conf else UpstreamConfig()

    async def _get_client_credentials(self) -> AWSClientCredentials | None:
        """Resolve credentials from the upstream config.

        An explicit key pair takes precedence; otherwise a configured role
        ARN is assumed to obtain temporary credentials. Returns ``None``
        when neither is configured.
        """
        key_id = self.upstream_config.aws_access_key_id
        secret_access_key = self.upstream_config.aws_secret_access_key
        if key_id and secret_access_key:
            return AWSClientCredentials(
                aws_access_key_id=key_id,
                aws_secret_access_key=secret_access_key,
            )

        if self.upstream_config.aws_assume_role_arn:
            return await self._get_assumed_role_tmp_credentials()

        return None

    async def _get_assumed_role_tmp_credentials(self) -> AWSClientCredentials:
        """Assume the configured role and return its temporary credentials."""
        region = self.upstream_config.region
        role_arn = self.upstream_config.aws_assume_role_arn
        session_name = self.BEDROCK_ACCESS_SESSION_NAME

        def _assume_role() -> dict:
            sts_client = boto3.Session().client("sts", region_name=region)
            return sts_client.assume_role(
                RoleArn=role_arn,
                RoleSessionName=session_name,
            )

        # Both the client construction and the assume_role request are
        # synchronous boto3 calls, so run them together via make_async
        # rather than calling assume_role directly inside the coroutine.
        assumed_role_object = await make_async(_assume_role)

        assumed_role_credentials = assumed_role_object["Credentials"]
        return AWSClientCredentials(
            aws_access_key_id=assumed_role_credentials["AccessKeyId"],
            aws_secret_access_key=assumed_role_credentials["SecretAccessKey"],
            aws_session_token=assumed_role_credentials["SessionToken"],
        )