aidial_adapter_bedrock/embedding/cohere/embed_text.py (72 lines of code) (raw):

""" Text Embeddings Adapter for Cohere Embed model See the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html https://docs.cohere.com/reference/embed """ from typing import AsyncIterator, List, Self from aidial_sdk.embeddings import Response as EmbeddingsResponse from aidial_sdk.embeddings.request import EmbeddingsRequest from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.dial_api.embedding_inputs import ( EMPTY_INPUT_LIST_ERROR, collect_embedding_inputs_without_attachments, ) from aidial_adapter_bedrock.dial_api.response import make_embeddings_response from aidial_adapter_bedrock.embedding.cohere.response import ( call_embedding_model, ) from aidial_adapter_bedrock.embedding.embeddings_adapter import ( EmbeddingsAdapter, ) from aidial_adapter_bedrock.embedding.validation import ( validate_embeddings_request, ) from aidial_adapter_bedrock.llm.errors import ValidationError from aidial_adapter_bedrock.utils.json import remove_nones def create_cohere_request(texts: List[str], input_type: str) -> dict: return remove_nones( { "texts": texts, "input_type": input_type, } ) def get_text_inputs(request: EmbeddingsRequest) -> AsyncIterator[str]: async def on_texts(texts: List[str]) -> str: if len(texts) == 0: raise EMPTY_INPUT_LIST_ERROR elif len(texts) == 1: return texts[0] else: raise ValidationError( "No more than one element is allowed in an element of custom_input list" ) return collect_embedding_inputs_without_attachments( request, on_texts=on_texts ) class CohereTextEmbeddings(EmbeddingsAdapter): model: str client: Bedrock @classmethod def create(cls, client: Bedrock, model: str) -> Self: return cls(client=client, model=model) async def embeddings( self, request: EmbeddingsRequest ) -> EmbeddingsResponse: validate_embeddings_request( request, supports_type=True, supports_dimensions=False, ) input_type: str | None = ( request.custom_fields and request.custom_fields.type ) if input_type is None: raise ValidationError( "Embedding type request parameter is required" ) text_inputs = [txt async for txt in get_text_inputs(request)] embedding_request = create_cohere_request(text_inputs, input_type) embeddings, tokens = await call_embedding_model( self.client, self.model, embedding_request ) return make_embeddings_response( model=self.model, encoding_format=request.encoding_format, vectors=embeddings, prompt_tokens=tokens, )