aidial_adapter_openai/utils/parsers.py (83 lines of code) (raw):

import re
from abc import ABC, abstractmethod
from json import JSONDecodeError
from typing import Any, Dict, TypedDict

from aidial_sdk.exceptions import InvalidRequestError
from fastapi import Request
from openai import AsyncAzureOpenAI, AsyncOpenAI, Timeout
from pydantic import BaseModel

from aidial_adapter_openai.utils.http_client import get_http_client


class OpenAIParams(TypedDict, total=False):
    """Optional settings forwarded to the OpenAI client constructors.

    Every key may be omitted (``total=False``); missing keys are passed to
    the client as ``None`` by the ``get_client`` methods below.
    """

    api_key: str
    azure_ad_token: str
    api_version: str
    timeout: Timeout


class Endpoint(ABC):
    """Abstract interface of an upstream endpoint able to build a client."""

    @abstractmethod
    def get_client(self, params: OpenAIParams) -> AsyncOpenAI:
        """Return an async OpenAI client configured from ``params``."""
        pass


# Retries are handled on the DIAL Core side
_MAX_RETRIES = 0


class AzureOpenAIEndpoint(BaseModel):
    """Azure OpenAI upstream: a resource base URL plus a deployment name."""

    azure_endpoint: str
    azure_deployment: str

    def get_client(self, params: OpenAIParams) -> AsyncAzureOpenAI:
        """Build an ``AsyncAzureOpenAI`` client for this deployment.

        Credentials, API version and timeout are read from ``params``;
        absent keys are passed as ``None``. Client-side retries are
        disabled via ``_MAX_RETRIES``.
        """
        return AsyncAzureOpenAI(
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            api_key=params.get("api_key"),
            azure_ad_token=params.get("azure_ad_token"),
            api_version=params.get("api_version"),
            timeout=params.get("timeout"),
            max_retries=_MAX_RETRIES,
            http_client=get_http_client(),
        )


class OpenAIEndpoint(BaseModel):
    """Plain OpenAI-style upstream identified by its base URL."""

    base_url: str

    def get_client(self, params: OpenAIParams) -> AsyncOpenAI:
        """Build an ``AsyncOpenAI`` client pointed at ``base_url``.

        Only ``api_key`` and ``timeout`` are read from ``params``.
        Client-side retries are disabled via ``_MAX_RETRIES``.
        """
        return AsyncOpenAI(
            base_url=self.base_url,
            api_key=params.get("api_key"),
            timeout=params.get("timeout"),
            max_retries=_MAX_RETRIES,
            http_client=get_http_client(),
        )


def _parse_endpoint(
    name: str, endpoint: str
) -> AzureOpenAIEndpoint | OpenAIEndpoint | None:
    """Split an upstream URL into its endpoint description.

    The Azure form ``<base>/openai/deployments/<deployment>/<name>`` is tried
    first, then the generic form ``<base_url>/<name>``. For example, with
    ``name="chat/completions"``:

    * ``https://host/openai/deployments/gpt-4/chat/completions`` yields
      ``AzureOpenAIEndpoint(azure_endpoint="https://host",
      azure_deployment="gpt-4")``;
    * ``https://api.openai.com/v1/chat/completions`` yields
      ``OpenAIEndpoint(base_url="https://api.openai.com/v1")``.

    Args:
        name: Operation path suffix to look for (e.g. ``"embeddings"``).
            It is matched literally.
        endpoint: Full upstream URL.

    Returns:
        The parsed endpoint, or ``None`` when neither form matches.
    """
    # Escape the suffix so it is always matched literally, even if a caller
    # passes a name containing regex metacharacters.
    suffix = re.escape(name)

    # The patterns are deliberately not anchored at the end: anything that
    # follows the suffix in the URL is ignored, as before.
    if azure_match := re.search(
        rf"(.+?)/openai/deployments/(.+?)/{suffix}", endpoint
    ):
        return AzureOpenAIEndpoint(
            azure_endpoint=azure_match[1],
            azure_deployment=azure_match[2],
        )
    elif openai_match := re.search(rf"(.+?)/{suffix}", endpoint):
        return OpenAIEndpoint(base_url=openai_match[1])
    else:
        return None


class EndpointParser(BaseModel):
    """Strict parser for upstream URLs ending in a given operation name."""

    # Operation path suffix this parser recognises, e.g. "chat/completions".
    name: str

    def parse(self, endpoint: str) -> AzureOpenAIEndpoint | OpenAIEndpoint:
        """Parse ``endpoint`` or fail.

        Raises:
            InvalidRequestError: If the URL matches neither supported form.
        """
        if result := _parse_endpoint(self.name, endpoint):
            return result
        raise InvalidRequestError("Invalid upstream endpoint format")


class CompletionsParser(BaseModel):
    """Lenient parser for legacy ``completions`` upstream URLs."""

    def parse(
        self, endpoint: str
    ) -> AzureOpenAIEndpoint | OpenAIEndpoint | None:
        """Parse a ``completions`` URL, returning ``None`` if it is not one.

        Chat-completions URLs are rejected up front: they also contain
        ``/completions`` and would otherwise be matched by the pattern below.
        """
        if "/chat/completions" in endpoint:
            return None
        return _parse_endpoint("completions", endpoint)


chat_completions_parser = EndpointParser(name="chat/completions")
embeddings_parser = EndpointParser(name="embeddings")
completions_parser = CompletionsParser()


async def parse_body(request: Request) -> Dict[str, Any]:
    """Read the request body as a JSON object.

    Args:
        request: Incoming FastAPI request.

    Returns:
        The decoded JSON object as a dictionary.

    Raises:
        InvalidRequestError: If the body cannot be decoded as JSON, or if
            the decoded value is not a JSON object.
    """
    try:
        data = await request.json()
    except (JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is what ``json.loads`` raises for a byte body
        # that is not valid UTF-8/16/32; it is not a JSONDecodeError, so it
        # must be listed explicitly to be reported as a client error.
        raise InvalidRequestError(
            "Your request contained invalid JSON: " + str(e)
        ) from e

    if not isinstance(data, dict):
        raise InvalidRequestError(str(data) + " is not of type 'object'")

    return data