aidial_adapter_bedrock/embedding/amazon/titan_text.py (72 lines of code) (raw):

""" Amazing Titan Text Embeddings Adapter See official cookbook for usage instructions: https://github.com/aws-samples/amazon-bedrock-samples/blob/5752afb78e7fab49cfd42d38bb09d40756bf0ea0/multimodal/Titan/embeddings/v2/Titan-V2-Embeddings.ipynb """ import asyncio from typing import AsyncIterator, List, Self, Tuple from aidial_sdk.embeddings import Response as EmbeddingsResponse from aidial_sdk.embeddings.request import EmbeddingsRequest from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.embedding_inputs import ( EMPTY_INPUT_LIST_ERROR, collect_embedding_inputs_without_attachments, ) from aidial_adapter_bedrock.dial_api.response import make_embeddings_response from aidial_adapter_bedrock.embedding.amazon.response import ( call_embedding_model, ) from aidial_adapter_bedrock.embedding.embeddings_adapter import ( EmbeddingsAdapter, ) from aidial_adapter_bedrock.embedding.validation import ( validate_embeddings_request, ) from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.utils.json import remove_nones def create_titan_request(input: str, dimensions: int | None) -> dict: return remove_nones({"inputText": input, "dimensions": dimensions}) def get_text_inputs(request: EmbeddingsRequest) -> AsyncIterator[str]: async def on_texts(texts: List[str]) -> str: if len(texts) == 0: raise EMPTY_INPUT_LIST_ERROR elif len(texts) == 1: return texts[0] else: raise ValidationError( "No more than one element is allowed in an element of custom_input list" ) return collect_embedding_inputs_without_attachments( request, on_texts=on_texts ) class AmazonTitanTextEmbeddings(EmbeddingsAdapter): model: str client: Bedrock supports_dimensions: bool @classmethod def create( cls, client: Bedrock, model: str, supports_dimensions: bool ) -> Self: return cls( client=client, model=model, supports_dimensions=supports_dimensions ) async def embeddings( self, request: EmbeddingsRequest ) -> EmbeddingsResponse: validate_embeddings_request( request, supports_type=False, supports_dimensions=self.supports_dimensions, ) async def compute_embeddings(req: str) -> Tuple[List[float], int]: return await call_embedding_model( self.client, self.model, create_titan_request(req, request.dimensions), ) # NOTE: Amazon Titan doesn't support batched inputs tasks = [ asyncio.create_task(compute_embeddings(req)) async for req in get_text_inputs(request) ] results = await asyncio.gather(*tasks) return make_embeddings_response( model=self.model, encoding_format=request.encoding_format, vectors=[r[0] for r in results], prompt_tokens=sum(r[1] for r in results), )