aidial_adapter_openai/app_config.py (100 lines of code) (raw):

import json import os from typing import Callable, Dict, List from pydantic import BaseModel from aidial_adapter_openai.constant import ChatCompletionDeploymentType from aidial_adapter_openai.utils.env import get_env_bool from aidial_adapter_openai.utils.json import remove_nones from aidial_adapter_openai.utils.log_config import logger class ApplicationConfig(BaseModel): MODEL_ALIASES: Dict[str, str] = {} DALLE3_DEPLOYMENTS: List[str] = [] GPT4_VISION_DEPLOYMENTS: List[str] = [] MISTRAL_DEPLOYMENTS: List[str] = [] DATABRICKS_DEPLOYMENTS: List[str] = [] GPT4O_DEPLOYMENTS: List[str] = [] GPT4O_MINI_DEPLOYMENTS: List[str] = [] AZURE_AI_VISION_DEPLOYMENTS: List[str] = [] API_VERSIONS_MAPPING: Dict[str, str] = {} COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES: Dict[str, str] = {} DALLE3_AZURE_API_VERSION: str = "2024-02-01" NON_STREAMING_DEPLOYMENTS: List[str] = [] ELIMINATE_EMPTY_CHOICES: bool = False DEPLOYMENT_TYPE_MAP: Dict[ ChatCompletionDeploymentType, Callable[["ApplicationConfig"], List[str]] ] = { ChatCompletionDeploymentType.DALLE3: lambda config: config.DALLE3_DEPLOYMENTS, ChatCompletionDeploymentType.GPT4_VISION: lambda config: config.GPT4_VISION_DEPLOYMENTS, ChatCompletionDeploymentType.MISTRAL: lambda config: config.MISTRAL_DEPLOYMENTS, ChatCompletionDeploymentType.DATABRICKS: lambda config: config.DATABRICKS_DEPLOYMENTS, ChatCompletionDeploymentType.GPT4O: lambda config: config.GPT4O_DEPLOYMENTS, ChatCompletionDeploymentType.GPT4O_MINI: lambda config: config.GPT4O_MINI_DEPLOYMENTS, } def get_chat_completion_deployment_type( self, deployment_id: str ) -> ChatCompletionDeploymentType: for deployment_type, config_getter in self.DEPLOYMENT_TYPE_MAP.items(): if deployment_id in config_getter(self): return deployment_type return ChatCompletionDeploymentType.GPT_TEXT_ONLY def add_deployment( self, deployment_id: str, deployment_type: ChatCompletionDeploymentType ): if deployment_type == ChatCompletionDeploymentType.GPT_TEXT_ONLY: return config_getter = self.DEPLOYMENT_TYPE_MAP[deployment_type] config_getter(self).append(deployment_id) @classmethod def from_env(cls) -> "ApplicationConfig": def _parse_env_deployments(deployments_key: str) -> List[str] | None: deployments_value = os.getenv(deployments_key) if deployments_value is None: return None return list(map(str.strip, (deployments_value).split(","))) def _parse_env_dict(key: str) -> Dict[str, str] | None: value = os.getenv(key) return json.loads(value) if value else None def _parse_eliminate_empty_choices() -> bool | None: old_name = "FIX_STREAMING_ISSUES_IN_NEW_API_VERSIONS" new_name = "ELIMINATE_EMPTY_CHOICES" if old_name in os.environ: logger.warning( f"{old_name} environment variable is deprecated. Use {new_name} instead." ) return get_env_bool(old_name) elif new_name in os.environ: return get_env_bool(new_name) return None deployment_fields = { deployment_key: _parse_env_deployments(deployment_key) for deployment_key in ( "DALLE3_DEPLOYMENTS", "GPT4_VISION_DEPLOYMENTS", "MISTRAL_DEPLOYMENTS", "DATABRICKS_DEPLOYMENTS", "GPT4O_DEPLOYMENTS", "GPT4O_MINI_DEPLOYMENTS", "AZURE_AI_VISION_DEPLOYMENTS", "NON_STREAMING_DEPLOYMENTS", ) } dict_fields = { key: _parse_env_dict(key) for key in ( "MODEL_ALIASES", "API_VERSIONS_MAPPING", "COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES", ) } return cls( **remove_nones( { **deployment_fields, **dict_fields, "DALLE3_AZURE_API_VERSION": os.getenv( "DALLE3_AZURE_API_VERSION" ), "ELIMINATE_EMPTY_CHOICES": _parse_eliminate_empty_choices(), } ), )