aidial_adapter_bedrock/adapter_deployments.py (106 lines of code) (raw):
import json
from enum import Enum
from typing import Dict, Generic, Iterable, Self, Tuple, TypeVar
from pydantic import BaseModel
from aidial_adapter_bedrock.deployments import (
CHAT_COMPLETION_REDIRECTS,
ChatCompletionDeployment,
EmbeddingsDeployment,
)
from aidial_adapter_bedrock.utils.log_config import app_logger as log
# Enum-bound type variable: the type of `reference_deployment_id` in
# AdapterDeployment, also used in the signature of _create_deployments.
_D = TypeVar("_D", bound=Enum)
# Second Enum-bound type variable: lets AdapterDeployment.clone return a
# deployment whose reference type differs from that of the source object.
_T = TypeVar("_T", bound=Enum)
class AdapterDeployment(BaseModel, Generic[_D]):
    # Ties together three identifiers for one model: the id exposed by the
    # adapter, the id sent upstream, and the reference deployment whose API
    # the upstream model shares.

    adapter_deployment_id: str
    """
    Id under which the Adapter serves the model, i.e. the `deployment_id` in
    the route /openai/deployments/{deployment_id}/(chat/completions|embeddings)
    """

    upstream_deployment_id: str
    """
    Id of the corresponding Bedrock model.
    Upstream requests to the Bedrock service are made with this deployment id.
    """

    reference_deployment_id: _D
    """
    Reference Bedrock deployment that is known to share the same API
    as `upstream_deployment_id`.
    """

    @classmethod
    def supported(
        cls, *, deployment_id: str | None = None, upstream: _D
    ) -> Self:
        """
        Build the entry for the deployment `upstream`.

        The adapter-facing id is `deployment_id` when it is truthy; otherwise
        it falls back to the value of `upstream`. Both the upstream id and the
        reference deployment are taken from `upstream` itself.
        """
        upstream_id = upstream.value
        exposed_id = deployment_id or upstream_id
        return cls(
            adapter_deployment_id=exposed_id,
            upstream_deployment_id=upstream_id,
            reference_deployment_id=upstream,
        )

    def compat(self, deployment_id: str) -> "AdapterDeployment[_D]":
        """
        Derive an entry for `deployment_id` that reuses this reference.

        `deployment_id` becomes both the adapter-facing id and the upstream
        id; only `reference_deployment_id` is carried over from `self`.
        """
        fields = {
            "adapter_deployment_id": deployment_id,
            "upstream_deployment_id": deployment_id,
            "reference_deployment_id": self.reference_deployment_id,
        }
        return AdapterDeployment(**fields)

    def clone(self, reference_deployment_id: _T) -> "AdapterDeployment[_T]":
        """
        Copy this entry, swapping in a different reference deployment.

        The adapter-facing and upstream ids are kept as they are.
        """
        fields = {
            "adapter_deployment_id": self.adapter_deployment_id,
            "upstream_deployment_id": self.upstream_deployment_id,
            "reference_deployment_id": reference_deployment_id,
        }
        return AdapterDeployment(**fields)
# AdapterDeployment parameterized with each of the two deployment types
# imported from aidial_adapter_bedrock.deployments; used as the value types
# of the AdapterDeployments fields.
AdapterChatCompletionDeployment = AdapterDeployment[ChatCompletionDeployment]
AdapterEmbeddingsDeployment = AdapterDeployment[EmbeddingsDeployment]
class AdapterDeployments(BaseModel):
    # Two lookup tables of adapter deployments, keyed by deployment id:
    # one for chat completion models and one for embeddings models.

    chat_completions: Dict[str, AdapterChatCompletionDeployment]
    embeddings: Dict[str, AdapterEmbeddingsDeployment]

    @classmethod
    def create(cls, *, compat_mapping: Dict[str, str]) -> "AdapterDeployments":
        """
        Build both deployment tables from the known deployments and a
        compatibility mapping.

        `compat_mapping` maps an extra deployment id onto the id of a
        deployment the adapter already supports.

        A warning is logged for every key that is itself one of the
        supported deployment ids.

        Raises:
            ValueError: if a chat completion deployment id is mapped onto an
                embeddings deployment id (or vice versa), or if some mapping
                values match no supported deployment of either kind.
        """
        chat_ids = {member.value for member in ChatCompletionDeployment}
        embeddings_ids = {member.value for member in EmbeddingsDeployment}

        for deployment_id, supported_id in compat_mapping.items():
            is_chat = deployment_id in chat_ids
            is_embeddings = deployment_id in embeddings_ids

            # Keys unknown to the adapter need neither the warning nor the
            # cross-kind checks below.
            if not (is_chat or is_embeddings):
                continue

            log.warning(
                f"{deployment_id!r} is one of the Bedrock deployments supported by the adapter already. "
                f"Remove {deployment_id!r} from the compatibility mapping to avoid the warning."
            )

            if is_chat and supported_id in embeddings_ids:
                raise ValueError(
                    f"The chat completion deployment {deployment_id!r} is mapped onto the embeddings deployment {supported_id!r}"
                )

            if is_embeddings and supported_id in chat_ids:
                raise ValueError(
                    f"The embeddings deployment {deployment_id!r} is mapped onto the chat completion deployment {supported_id!r}"
                )

        # Each pass consumes the mapping entries it could resolve and hands
        # back the remainder for the next pass.
        compat_mapping, chat_table = _create_deployments(
            compat_mapping,
            ChatCompletionDeployment,
            redirects=CHAT_COMPLETION_REDIRECTS,
        )
        compat_mapping, embeddings_table = _create_deployments(
            compat_mapping, EmbeddingsDeployment
        )

        # Whatever is left points at ids that are not supported at all.
        if compat_mapping:
            raise ValueError(
                f"None of the values in the following compatibility mapping corresponds to a Bedrock deployment supported by the adapter: {json.dumps(compat_mapping)}. "
                f"Remap the deployments to the supported Bedrock deployments to fix the error."
            )

        ret = cls(chat_completions=chat_table, embeddings=embeddings_table)
        log.debug(f"Adapter deployments: {ret.json()}")
        return ret
def _create_deployments(
    compat_mapping: Dict[str, str],
    upstream_deployments: Iterable[_D],
    *,
    redirects: Dict[_D, _D] | None = None,
) -> Tuple[Dict[str, str], Dict[str, AdapterDeployment[_D]]]:
    """
    Build the deployment table for one family of upstream deployments.

    Every item of `upstream_deployments` gets an entry keyed by its value;
    if the item is a key in `redirects`, the entry is built from the
    redirect target instead. Each `compat_mapping` entry whose value names
    one of those entries adds a compat entry under the mapping key.

    Returns a pair:
      * the `compat_mapping` entries that could not be resolved here
        (a new dict; the argument itself is not modified);
      * the combined table, where a compat entry replaces a directly
        supported entry that has the same key.
    """
    redirect_table = redirects or {}

    native: Dict[str, AdapterDeployment[_D]] = {
        member.value: AdapterDeployment.supported(
            deployment_id=member.value,
            upstream=redirect_table.get(member, member),
        )
        for member in upstream_deployments
    }

    # Split the mapping in one pass: resolved entries become compat
    # deployments, the rest is returned to the caller in original order.
    unresolved: Dict[str, str] = {}
    aliases: Dict[str, AdapterDeployment[_D]] = {}
    for alias_id, target_id in compat_mapping.items():
        target = native.get(target_id)
        if target is None:
            unresolved[alias_id] = target_id
        else:
            aliases[alias_id] = target.compat(alias_id)

    return unresolved, native | aliases