aidial_sdk/deployment/from_request_mixin.py (91 lines of code) (raw):

from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Mapping, Optional, Type, TypeVar

import fastapi
from pydantic import Field

from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.pydantic_v1 import SecretStr, StrictStr, root_validator
from aidial_sdk.utils.pydantic import ExtraForbidModel

T = TypeVar("T", bound="FromRequestMixin")


class FromRequestMixin(ABC, ExtraForbidModel):
    """Abstract interface for models that are constructed from a request."""

    @classmethod
    @abstractmethod
    async def from_request(
        cls: Type[T], request: fastapi.Request, deployment_id: str
    ) -> T:
        """Build an instance of ``cls`` from ``request``.

        Args:
            request: The incoming FastAPI request.
            deployment_id: Identifier of the deployment being addressed.
        """
        pass

    @staticmethod
    @abstractmethod
    async def get_request_body(request: fastapi.Request) -> Any:
        """Extract the payload of ``request`` used to populate the model."""
        pass


class FromRequestBasicMixin(FromRequestMixin):
    """Model populated solely from the JSON body of the request."""

    @classmethod
    async def from_request(cls, request: fastapi.Request, deployment_id: str):
        """Build the model from the request body; ``deployment_id`` is unused."""
        return cls(**(await cls.get_request_body(request)))

    @staticmethod
    async def get_request_body(request: fastapi.Request) -> dict:
        """Return the request body parsed as a JSON object."""
        return await _get_request_json_body(request)


class FromRequestDeploymentMixin(FromRequestMixin):
    """Model populated from the JSON body plus request-level metadata.

    Besides the body fields declared by subclasses, the model carries the
    credentials taken from the ``Api-Key`` and ``Authorization`` headers,
    the deployment id, the ``api-version`` query parameter, the remaining
    headers and the original request object.
    """

    # Credentials are stored as SecretStr; the plain values are exposed
    # through the `api_key` and `jwt` properties below.
    api_key_secret: SecretStr
    jwt_secret: Optional[SecretStr] = None

    deployment_id: StrictStr
    api_version: Optional[StrictStr] = None
    headers: Mapping[StrictStr, StrictStr]

    # NOTE(review): `Field` is imported from `pydantic` while the validators
    # come from `aidial_sdk.pydantic_v1` -- confirm both resolve to the same
    # pydantic API.
    original_request: fastapi.Request = Field(..., exclude=True)

    class Config:
        # `fastapi.Request` is not a pydantic-aware type.
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def create_secrets(cls, values: dict):
        """Accept plain ``api_key``/``jwt`` inputs and wrap them in SecretStr.

        Raises:
            ValueError: If both the plain and the ``*_secret`` form of the
                same credential are present in ``values``.
        """
        if "api_key" in values:
            if "api_key_secret" not in values:
                values["api_key_secret"] = SecretStr(values.pop("api_key"))
            else:
                raise ValueError(
                    "api_key and api_key_secret cannot be both provided"
                )

        if "jwt" in values:
            if "jwt_secret" not in values:
                values["jwt_secret"] = SecretStr(values.pop("jwt"))
            else:
                raise ValueError("jwt and jwt_secret cannot be both provided")

        return values

    @property
    def api_key(self) -> str:
        """The API key as a plain string."""
        return self.api_key_secret.get_secret_value()

    @property
    def jwt(self) -> Optional[str]:
        """The ``Authorization`` header value as a plain string, if any."""
        return self.jwt_secret.get_secret_value() if self.jwt_secret else None

    @classmethod
    async def from_request(cls, request: fastapi.Request, deployment_id: str):
        """Build the model from the request body, headers and query string.

        Args:
            request: The incoming FastAPI request.
            deployment_id: Identifier of the deployment being addressed.

        Raises:
            DIALException: With status 400 if the ``Api-Key`` header is
                missing, or if the body is rejected by ``get_request_body``.
        """
        # Work on a copy so that the credentials can be stripped from the
        # headers stored on the model without touching the request itself.
        headers = request.headers.mutablecopy()

        api_key = headers.get("Api-Key")
        if api_key is None:
            raise DIALException(
                status_code=400,
                type="invalid_request_error",
                message="Api-Key header is required",
            )

        del headers["Api-Key"]

        # The Authorization header is optional; the unconditional delete
        # relies on the mutable header copy tolerating an absent key.
        jwt = headers.get("Authorization")
        del headers["Authorization"]

        # NOTE(review): a body key that matches one of the explicit keyword
        # arguments below (e.g. "deployment_id") makes this call raise
        # TypeError -- confirm whether that should be reported as a 400.
        return cls(
            **(await cls.get_request_body(request)),
            api_key_secret=SecretStr(api_key),
            jwt_secret=SecretStr(jwt) if jwt else None,
            deployment_id=deployment_id,
            api_version=request.query_params.get("api-version"),
            headers=headers,
            original_request=request,
        )

    @staticmethod
    async def get_request_body(request: fastapi.Request) -> dict:
        """Return the request body parsed as a JSON object."""
        return await _get_request_json_body(request)


async def _get_request_json_body(request: fastapi.Request) -> dict:
    """Parse the body of ``request`` as a JSON object.

    Returns:
        The decoded body, guaranteed to be a ``dict`` so that callers can
        safely unpack it with ``**``.

    Raises:
        DIALException: With status 400 and type ``invalid_request_error`` if
            the body cannot be decoded as JSON or is not a JSON object.
    """
    try:
        body = await request.json()
    except JSONDecodeError as e:
        raise DIALException(
            status_code=400,
            type="invalid_request_error",
            message=f"The request body isn't valid JSON: {e.msg}",
        ) from e
    except UnicodeDecodeError as e:
        # Bytes that aren't decodable text fail before JSON parsing starts
        # and would otherwise escape as an unhandled error.
        raise DIALException(
            status_code=400,
            type="invalid_request_error",
            message=f"The request body isn't valid JSON: {e.reason}",
        ) from e

    # Valid JSON such as `[]`, `"text"` or `null` cannot be unpacked into
    # keyword arguments; report it as a client error instead of a TypeError.
    if not isinstance(body, dict):
        raise DIALException(
            status_code=400,
            type="invalid_request_error",
            message="The request body must be a JSON object",
        )

    return body