aidial_adapter_vertexai/adapters.py (90 lines of code) (raw):

from typing import assert_never from google.genai.client import Client as GenAIClient from aidial_adapter_vertexai.chat.bison.adapter import ( BisonChatAdapter, BisonCodeChatAdapter, ) from aidial_adapter_vertexai.chat.chat_completion_adapter import ( ChatCompletionAdapter, ) from aidial_adapter_vertexai.chat.gemini.adapter import ( GeminiChatCompletionAdapter, GeminiGenAIChatCompletionAdapter, ) from aidial_adapter_vertexai.chat.imagen.adapter import ( ImagenChatCompletionAdapter, ) from aidial_adapter_vertexai.deployments import ( ChatCompletionDeployment, EmbeddingsDeployment, ) from aidial_adapter_vertexai.dial_api.storage import create_file_storage from aidial_adapter_vertexai.embedding.embeddings_adapter import ( EmbeddingsAdapter, ) from aidial_adapter_vertexai.embedding.multi_modal import ( MultiModalEmbeddingsAdapter, ) from aidial_adapter_vertexai.embedding.text import TextEmbeddingsAdapter async def get_chat_completion_model( api_key: str, deployment: ChatCompletionDeployment, client: GenAIClient ) -> ChatCompletionAdapter: model_id = deployment.get_model_id() match deployment: case ( ChatCompletionDeployment.CHAT_BISON_1 | ChatCompletionDeployment.CHAT_BISON_2 | ChatCompletionDeployment.CHAT_BISON_2_32K ): return await BisonChatAdapter.create(model_id) case ( ChatCompletionDeployment.CODECHAT_BISON_1 | ChatCompletionDeployment.CODECHAT_BISON_2 | ChatCompletionDeployment.CODECHAT_BISON_2_32K ): return await BisonCodeChatAdapter.create(model_id) case ( ChatCompletionDeployment.GEMINI_PRO_1 | ChatCompletionDeployment.GEMINI_PRO_VISION_1 | ChatCompletionDeployment.GEMINI_PRO_1_5_PREVIEW | ChatCompletionDeployment.GEMINI_PRO_1_5_V1 | ChatCompletionDeployment.GEMINI_PRO_1_5_V2 | ChatCompletionDeployment.GEMINI_FLASH_1_5_V1 | ChatCompletionDeployment.GEMINI_FLASH_1_5_V2 ): storage = create_file_storage(api_key) return await GeminiChatCompletionAdapter.create( storage, model_id, deployment ) case ( ChatCompletionDeployment.GEMINI_2_0_FLASH_EXP | ChatCompletionDeployment.GEMINI_2_0_FLASH_THINKING_EXP_1219 | ChatCompletionDeployment.GEMINI_2_0_EXPERIMENTAL_1206 ): storage = create_file_storage(api_key) return GeminiGenAIChatCompletionAdapter( client, storage, model_id, deployment ) case ChatCompletionDeployment.IMAGEN_005: storage = create_file_storage(api_key) return await ImagenChatCompletionAdapter.create(storage, model_id) case _: assert_never(deployment) async def get_embeddings_model( api_key: str, deployment: EmbeddingsDeployment ) -> EmbeddingsAdapter: model_id = deployment.get_model_id() match deployment: case ( EmbeddingsDeployment.TEXT_EMBEDDING_GECKO_1 | EmbeddingsDeployment.TEXT_EMBEDDING_GECKO_3 | EmbeddingsDeployment.TEXT_EMBEDDING_4 | EmbeddingsDeployment.TEXT_EMBEDDING_GECKO_MULTILINGUAL_1 | EmbeddingsDeployment.TEXT_MULTILINGUAL_EMBEDDING_2 ): return await TextEmbeddingsAdapter.create(model_id) case EmbeddingsDeployment.MULTI_MODAL_EMBEDDING_1: storage = create_file_storage(api_key) return await MultiModalEmbeddingsAdapter.create(storage, model_id) case _: assert_never(deployment)