# aidial_adapter_bedrock/adapter_deployments.py (106 lines of code) (raw):

import json
from enum import Enum
from typing import Dict, Generic, Iterable, Self, Tuple, TypeVar

from pydantic import BaseModel

from aidial_adapter_bedrock.deployments import (
    CHAT_COMPLETION_REDIRECTS,
    ChatCompletionDeployment,
    EmbeddingsDeployment,
)
from aidial_adapter_bedrock.utils.log_config import app_logger as log

_D = TypeVar("_D", bound=Enum)
_T = TypeVar("_T", bound=Enum)


# A single deployment exposed by the adapter: the id it is served under, the
# id sent upstream, and the reference deployment it is treated as.
# (Kept as a comment rather than a class docstring so the model's generated
# schema description is left untouched.)
class AdapterDeployment(BaseModel, Generic[_D]):
    adapter_deployment_id: str
    """
    The deployment id under which the model is served by the Adapter at the route
    /openai/deployments/{deployment_id}/(chat/completions|embeddings)
    """

    upstream_deployment_id: str
    """
    The deployment id of the corresponding Bedrock model.
    The upstream request to the Bedrock service will use this deployment id.
    """

    reference_deployment_id: _D
    """
    The reference Bedrock deployment which is known to
    share the same API as `upstream_deployment_id`.
    """

    @classmethod
    def supported(
        cls, *, deployment_id: str | None = None, upstream: _D
    ) -> Self:
        """Build the entry for a deployment taken from a deployment enum.

        The upstream id is `upstream.value` and the reference deployment is
        `upstream` itself. The adapter id is `deployment_id` when it is
        truthy, otherwise it falls back to `upstream.value`.
        """
        upstream_id = upstream.value
        return cls(
            adapter_deployment_id=deployment_id or upstream_id,
            upstream_deployment_id=upstream_id,
            reference_deployment_id=upstream,
        )

    def compat(self, deployment_id: str) -> "AdapterDeployment[_D]":
        """Return a new entry that keeps this entry's reference deployment
        but uses `deployment_id` as both the adapter id and the upstream id.
        """
        reference = self.reference_deployment_id
        return AdapterDeployment(
            adapter_deployment_id=deployment_id,
            upstream_deployment_id=deployment_id,
            reference_deployment_id=reference,
        )

    def clone(self, reference_deployment_id: _T) -> "AdapterDeployment[_T]":
        """Return a new entry with the same adapter and upstream ids as this
        one, but with `reference_deployment_id` as the reference deployment.
        """
        adapter_id = self.adapter_deployment_id
        upstream_id = self.upstream_deployment_id
        return AdapterDeployment(
            adapter_deployment_id=adapter_id,
            upstream_deployment_id=upstream_id,
            reference_deployment_id=reference_deployment_id,
        )


AdapterChatCompletionDeployment = AdapterDeployment[ChatCompletionDeployment]
AdapterEmbeddingsDeployment = AdapterDeployment[EmbeddingsDeployment]


# The full deployment table of the adapter, keyed by adapter deployment id and
# split into chat completion and embeddings deployments.
class AdapterDeployments(BaseModel):
    chat_completions: Dict[str, AdapterChatCompletionDeployment]
    embeddings: Dict[str, AdapterEmbeddingsDeployment]

    @classmethod
    def create(cls, *, compat_mapping: Dict[str, str]) -> "AdapterDeployments":
        """Build the deployment table from the enums and a compatibility mapping.

        `compat_mapping` maps an extra deployment id onto the id of a
        deployment from `ChatCompletionDeployment` or `EmbeddingsDeployment`.

        A warning is logged for every mapping key that is itself one of the
        enum values.

        Raises:
            ValueError: if a chat completion id is mapped onto an embeddings
                id or the other way round, or if some mapping values match
                neither enum.
        """
        chat_ids = {member.value for member in ChatCompletionDeployment}
        embeddings_ids = {member.value for member in EmbeddingsDeployment}

        for deployment_id, supported_id in compat_mapping.items():
            is_chat = deployment_id in chat_ids
            is_embeddings = deployment_id in embeddings_ids

            # Only keys that collide with a built-in deployment id need
            # attention; both error cases below require such a collision.
            if not (is_chat or is_embeddings):
                continue

            log.warning(
                f"{deployment_id!r} is one of the Bedrock deployments supported by the adapter already. "
                f"Remove {deployment_id!r} from the compatibility mapping to avoid the warning."
            )

            if is_chat and supported_id in embeddings_ids:
                raise ValueError(
                    f"The chat completion deployment {deployment_id!r} is mapped onto the embeddings deployment {supported_id!r}"
                )

            if is_embeddings and supported_id in chat_ids:
                raise ValueError(
                    f"The embeddings deployment {deployment_id!r} is mapped onto the chat completion deployment {supported_id!r}"
                )

        # Each call consumes the mapping entries it could resolve and hands
        # back the rest, so whatever is left at the end matched nothing.
        compat_mapping, chat_table = _create_deployments(
            compat_mapping,
            ChatCompletionDeployment,
            redirects=CHAT_COMPLETION_REDIRECTS,
        )
        compat_mapping, embeddings_table = _create_deployments(
            compat_mapping, EmbeddingsDeployment
        )

        if compat_mapping:
            raise ValueError(
                f"None of the values in the following compatibility mapping corresponds to a Bedrock deployment supported by the adapter: {json.dumps(compat_mapping)}. "
                f"Remap the deployments to the supported Bedrock deployments to fix the error."
            )

        ret = cls(chat_completions=chat_table, embeddings=embeddings_table)
        log.debug(f"Adapter deployments: {ret.json()}")
        return ret


def _create_deployments(
    compat_mapping: Dict[str, str],
    upstream_deployments: Iterable[_D],
    *,
    redirects: Dict[_D, _D] | None = None,
) -> Tuple[Dict[str, str], Dict[str, AdapterDeployment[_D]]]:
    """Build the deployment table for one family of upstream deployments.

    Every member of `upstream_deployments` gets an entry keyed by its own
    value; when `redirects` has an entry for the member, the redirect target
    is used as the upstream/reference deployment instead of the member.

    Every `compat_mapping` item whose value is a key of that table gets an
    additional entry keyed by the mapping key. On a key clash the mapping
    entry replaces the built-in one.

    Returns:
        A pair of (the mapping items that were not resolved, the table).
        The input mapping is not modified.
    """
    redirect_table = redirects or {}

    supported: Dict[str, AdapterDeployment[_D]] = {
        upstream.value: AdapterDeployment.supported(
            deployment_id=upstream.value,
            upstream=redirect_table.get(upstream, upstream),
        )
        for upstream in upstream_deployments
    }

    # Work on a copy so the caller's mapping stays intact; iterating the
    # original while deleting from the copy avoids a snapshot list.
    unresolved = compat_mapping.copy()
    compat: Dict[str, AdapterDeployment[_D]] = {}

    for alias_id, target_id in compat_mapping.items():
        target = supported.get(target_id)
        if target is not None:
            del unresolved[alias_id]
            compat[alias_id] = target.compat(alias_id)

    return unresolved, supported | compat