aidial_adapter_openai/utils/auth.py (64 lines of code) (raw):
import os
import time
from typing import Mapping, Optional, TypedDict
from aidial_sdk.exceptions import HTTPException as DialException
from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity.aio import DefaultAzureCredential
from fastapi import Request
from pydantic import BaseModel
from aidial_adapter_openai.utils.log_config import logger
default_credential = DefaultAzureCredential()
access_token: AccessToken | None = None
EXPIRATION_WINDOW_IN_SEC: int = int(
os.getenv("ACCESS_TOKEN_EXPIRATION_WINDOW", 10)
)
AZURE_OPEN_AI_SCOPE: str = os.getenv(
"AZURE_OPEN_AI_SCOPE", "https://cognitiveservices.azure.com/.default"
)
async def get_api_key() -> str:
now = int(time.time())
global access_token
if (
access_token is None
or now + EXPIRATION_WINDOW_IN_SEC > access_token.expires_on
):
try:
access_token = await default_credential.get_token(
AZURE_OPEN_AI_SCOPE
)
except ClientAuthenticationError as e:
logger.error(
f"Default Azure credential failed with the error: {e.message}"
)
raise DialException("Authentication failed", 401, "Unauthorized")
return access_token.token
class OpenAICreds(TypedDict, total=False):
api_key: str
azure_ad_token: str
async def get_credentials(request: Request) -> OpenAICreds:
api_key = request.headers.get("X-UPSTREAM-KEY")
if api_key is None:
return {"azure_ad_token": await get_api_key()}
else:
return {"api_key": api_key}
def get_auth_headers(creds: OpenAICreds) -> dict[str, str]:
if "api_key" in creds:
return {"api-key": creds["api_key"]}
if "azure_ad_token" in creds:
return {"Authorization": f"Bearer {creds['azure_ad_token']}"}
raise ValueError("Invalid credentials")
class Auth(BaseModel):
name: str
value: str
@property
def headers(self) -> dict[str, str]:
return {self.name: self.value}
@classmethod
def from_headers(
cls, name: str, headers: Mapping[str, str]
) -> Optional["Auth"]:
value = headers.get(name)
if value is None:
return None
return cls(name=name, value=value)