aidial_adapter_bedrock/llm/model/adapter.py (196 lines of code) (raw):

from typing import assert_never from aidial_adapter_bedrock.adapter_deployments import ( AdapterChatCompletionDeployment, AdapterEmbeddingsDeployment, ) from aidial_adapter_bedrock.aws_client_config import AWSClientConfig from aidial_adapter_bedrock.bedrock import Bedrock from aidial_adapter_bedrock.deployments import ( ChatCompletionDeployment, EmbeddingsDeployment, ) from aidial_adapter_bedrock.embedding.amazon.titan_image import ( AmazonTitanImageEmbeddings, ) from aidial_adapter_bedrock.embedding.amazon.titan_text import ( AmazonTitanTextEmbeddings, ) from aidial_adapter_bedrock.embedding.cohere.embed_text import ( CohereTextEmbeddings, ) from aidial_adapter_bedrock.embedding.embeddings_adapter import ( EmbeddingsAdapter, ) from aidial_adapter_bedrock.llm.chat_model import ChatCompletionAdapter from aidial_adapter_bedrock.llm.converse.factory import ( ConverseAdapterFactory, ToolsSupport, ) from aidial_adapter_bedrock.llm.converse.types import ( ConverseDocumentType, ConverseImageType, ) from aidial_adapter_bedrock.llm.model.ai21 import AI21Adapter from aidial_adapter_bedrock.llm.model.amazon import AmazonAdapter from aidial_adapter_bedrock.llm.model.claude.v1_v2.adapter import ( Adapter as Claude_V1_V2, ) from aidial_adapter_bedrock.llm.model.claude.v3.adapter import ( Adapter as Claude_V3, ) from aidial_adapter_bedrock.llm.model.cohere import CohereAdapter from aidial_adapter_bedrock.llm.model.stability.v1 import StabilityV1Adapter from aidial_adapter_bedrock.llm.model.stability.v2 import StabilityV2Adapter async def get_bedrock_adapter( *, deployment: AdapterChatCompletionDeployment, api_key: str, aws_client_config: AWSClientConfig, ) -> ChatCompletionAdapter: model = deployment.upstream_deployment_id converse_adapter = ConverseAdapterFactory( deployment=model, aws_client_config=aws_client_config, api_key=api_key ) match deployment.reference_deployment_id: case ( ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_US | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_SONNET_EU | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_US | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_EU | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2 | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_SONNET_V2_US | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_HAIKU | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_5_HAIKU_US | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_US | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_HAIKU_EU | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V3_OPUS_US ): return Claude_V3.create( deployment.clone(deployment.reference_deployment_id), api_key, aws_client_config, ) case ( ChatCompletionDeployment.ANTHROPIC_CLAUDE_INSTANT_V1 | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V2 | ChatCompletionDeployment.ANTHROPIC_CLAUDE_V2_1 ): return await Claude_V1_V2.create( await Bedrock.acreate(aws_client_config), model ) case ( ChatCompletionDeployment.AI21_J2_JUMBO_INSTRUCT | ChatCompletionDeployment.AI21_J2_GRANDE_INSTRUCT | ChatCompletionDeployment.AI21_J2_MID_V1 | ChatCompletionDeployment.AI21_J2_ULTRA_V1 ): return AI21Adapter.create( await Bedrock.acreate(aws_client_config), model ) case ( ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL | ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_XL_V1 ): return StabilityV1Adapter.create( await Bedrock.acreate(aws_client_config), model, api_key ) case ( ChatCompletionDeployment.STABILITY_STABLE_IMAGE_CORE_V1 | ChatCompletionDeployment.STABILITY_STABLE_IMAGE_ULTRA_V1 ): return StabilityV2Adapter.create( await Bedrock.acreate(aws_client_config), model, api_key, image_to_image_supported=False, ) case ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_3_LARGE_V1: return StabilityV2Adapter.create( await Bedrock.acreate(aws_client_config), model, api_key, image_to_image_supported=True, image_width_constraints=(640, 1536), image_height_constraints=(640, 1536), ) case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE: return AmazonAdapter.create( await Bedrock.acreate(aws_client_config), model ) case ( ChatCompletionDeployment.COHERE_COMMAND_TEXT_V14 | ChatCompletionDeployment.COHERE_COMMAND_LIGHT_TEXT_V14 ): return CohereAdapter.create( await Bedrock.acreate(aws_client_config), model ) case ( ChatCompletionDeployment.AMAZON_NOVA_MICRO | ChatCompletionDeployment.AMAZON_NOVA_PRO | ChatCompletionDeployment.AMAZON_NOVA_LITE ): return await converse_adapter.create( tools_support=ToolsSupport.ALWAYS, supported_image_types=ConverseImageType.all(), supported_document_types=ConverseDocumentType.all(), ) case ( ChatCompletionDeployment.META_LLAMA3_8B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_70B_INSTRUCT_V1 ): return await converse_adapter.create( supported_image_types=ConverseImageType.all(), ) case ( ChatCompletionDeployment.META_LLAMA3_1_8B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_2_1B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_2_3B_INSTRUCT_V1 ): return await converse_adapter.create() case ( ChatCompletionDeployment.META_LLAMA3_1_70B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_1_405B_INSTRUCT_V1 ): return await converse_adapter.create( tools_support=ToolsSupport.NON_STREAMING_ONLY, ) case ( ChatCompletionDeployment.META_LLAMA3_2_11B_INSTRUCT_V1 | ChatCompletionDeployment.META_LLAMA3_2_90B_INSTRUCT_V1 ): return await converse_adapter.create( tools_support=ToolsSupport.NON_STREAMING_ONLY, supported_image_types=ConverseImageType.all(), ) case _: assert_never(deployment) async def get_embeddings_model( *, deployment: AdapterEmbeddingsDeployment, api_key: str, aws_client_config: AWSClientConfig, ) -> EmbeddingsAdapter: model = deployment.upstream_deployment_id client = await Bedrock.acreate(aws_client_config) match deployment.reference_deployment_id: case EmbeddingsDeployment.AMAZON_TITAN_EMBED_TEXT_V1: return AmazonTitanTextEmbeddings.create( client, model, supports_dimensions=False ) case EmbeddingsDeployment.AMAZON_TITAN_EMBED_TEXT_V2: return AmazonTitanTextEmbeddings.create( client, model, supports_dimensions=True ) case EmbeddingsDeployment.AMAZON_TITAN_EMBED_IMAGE_V1: return AmazonTitanImageEmbeddings.create(client, model, api_key) case ( EmbeddingsDeployment.COHERE_EMBED_ENGLISH_V3 | EmbeddingsDeployment.COHERE_EMBED_MULTILINGUAL_V3 ): return CohereTextEmbeddings.create(client, model) case _: assert_never(deployment)