# aidial_sdk/application.py
import logging.config
import re
import warnings
from logging import Filter, LogRecord
from typing import Any, Callable, Coroutine, Literal, Optional, Type, TypeVar

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_sdk._errors import (
    dial_exception_handler,
    fastapi_exception_handler,
    pydantic_validation_exception_handler,
)
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
    Response as ChatCompletionResponse,
)
from aidial_sdk.deployment.configuration import ConfigurationRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestMixin
from aidial_sdk.deployment.rate import RateRequest
from aidial_sdk.deployment.tokenize import TokenizeRequest
from aidial_sdk.deployment.truncate_prompt import TruncatePromptRequest
from aidial_sdk.embeddings.base import Embeddings
from aidial_sdk.embeddings.request import Request as EmbeddingsRequest
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.header_propagator import HeaderPropagator
from aidial_sdk.pydantic_v1 import ValidationError
from aidial_sdk.telemetry.types import TelemetryConfig
from aidial_sdk.utils._reflection import get_method_implementation
from aidial_sdk.utils.log_config import LogConfig
from aidial_sdk.utils.logging import log_debug, set_log_deployment
from aidial_sdk.utils.streaming import (
    add_heartbeat,
    to_block_response,
    to_streaming_response,
)

logging.config.dictConfig(LogConfig().dict())

RequestType = TypeVar("RequestType", bound=FromRequestMixin)


class PathFilter(Filter):
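    """Log filter that hides access-log lines mentioning a given path.

    Used below to keep healthcheck probes out of uvicorn's access log.
    """
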
    path: str

    def __init__(self, path: str) -> None:
        super().__init__(name="")
        self.path = path

    def filter(self, record: LogRecord):
        return not re.search(rf"(\s+){self.path}(\s+)", record.getMessage())


class DIALApp(FastAPI):
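    """FastAPI application pre-wired for the DIAL API.

    Registers DIAL-specific exception handlers and provides helpers for
    mounting chat completion and embeddings deployments under the
    ``/openai/deployments/{deployment_name}/...`` routes.

    A minimal usage sketch (``EchoChat`` is a hypothetical
    ``ChatCompletion`` implementation, not part of the SDK)::

        app = DIALApp(add_healthcheck=True)
        app.add_chat_completion("echo", EchoChat())
        # run with: uvicorn main:app
    """
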
    def __init__(
        self,
        dial_url: Optional[str] = None,
        propagate_auth_headers: bool = False,
        telemetry_config: Optional[TelemetryConfig] = None,
        add_healthcheck: bool = False,
        **kwargs,
    ):
if "propagation_auth_headers" in kwargs:
warnings.warn(
"The 'propagation_auth_headers' parameter is deprecated. "
"Use 'propagate_auth_headers' instead.",
DeprecationWarning,
stacklevel=2,
)
propagate_auth_headers = kwargs.pop("propagation_auth_headers")
super().__init__(**kwargs)
if telemetry_config is not None:
self.configure_telemetry(telemetry_config)
if propagate_auth_headers:
if not dial_url:
raise ValueError(
"dial_url is required if propagation auth headers is enabled"
)
HeaderPropagator(self, dial_url).enable()
if add_healthcheck:
path = "/health"
self.add_api_route(path, DIALApp._healthcheck, methods=["GET"])
logging.getLogger("uvicorn.access").addFilter(PathFilter(path))
self.add_exception_handler(
ValidationError, pydantic_validation_exception_handler
)
self.add_exception_handler(HTTPException, fastapi_exception_handler)
self.add_exception_handler(DIALException, dial_exception_handler)

    def configure_telemetry(self, config: TelemetryConfig):
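        """Initialize OpenTelemetry instrumentation for this app.

        Requires the optional dependencies installed via
        ``aidial-sdk[telemetry]``.
        """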
        try:
            from aidial_sdk.telemetry.init import init_telemetry
        except ImportError:
            raise ValueError(
                "Missing telemetry dependencies. "
                "Install the package with the extras: aidial-sdk[telemetry]"
            )

        init_telemetry(app=self, config=config)

    def add_embeddings(
        self, deployment_name: str, impl: Embeddings
    ) -> "DIALApp":
        self.add_api_route(
            f"/openai/deployments/{deployment_name}/embeddings",
            self._embeddings(deployment_name, impl),
            methods=["POST"],
        )
        return self

    def add_chat_completion(
        self,
        deployment_name: str,
        impl: ChatCompletion,
        *,
        heartbeat_interval: Optional[float] = None,
    ) -> "DIALApp":
        self.add_api_route(
            f"/openai/deployments/{deployment_name}/chat/completions",
            self._chat_completion(
                deployment_name,
                impl,
                heartbeat_interval=heartbeat_interval,
            ),
            methods=["POST"],
        )

        self.add_api_route(
            f"/openai/deployments/{deployment_name}/rate",
            self._rate_response(deployment_name, impl),
            methods=["POST"],
        )

        # Optional endpoints: registered only when overridden by the impl.
        if endpoint_impl := get_method_implementation(impl, "tokenize"):
            self.add_api_route(
                f"/openai/deployments/{deployment_name}/tokenize",
                self._endpoint_factory(
                    deployment_name, endpoint_impl, "tokenize", TokenizeRequest
                ),
                methods=["POST"],
            )

        if endpoint_impl := get_method_implementation(impl, "truncate_prompt"):
            self.add_api_route(
                f"/openai/deployments/{deployment_name}/truncate_prompt",
                self._endpoint_factory(
                    deployment_name,
                    endpoint_impl,
                    "truncate_prompt",
                    TruncatePromptRequest,
                ),
                methods=["POST"],
            )

        if endpoint_impl := get_method_implementation(impl, "configuration"):
            self.add_api_route(
                f"/openai/deployments/{deployment_name}/configuration",
                self._endpoint_factory(
                    deployment_name,
                    endpoint_impl,
                    "configuration",
                    ConfigurationRequest,
                ),
                methods=["GET"],
            )

        return self

    def _endpoint_factory(
        self,
        deployment_id: str,
        endpoint_impl: Callable[[RequestType], Coroutine[Any, Any, Any]],
        endpoint: Literal["tokenize", "truncate_prompt", "configuration"],
        request_type: Type["RequestType"],
    ):
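        """Build a generic JSON-in/JSON-out handler for the given endpoint."""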
        async def _handler(original_request: Request) -> Response:
            set_log_deployment(deployment_id)
            request = await request_type.from_request(
                original_request, deployment_id
            )
            log_debug(f"request[{endpoint}]: {request}")

            response = await endpoint_impl(request)

            response_json = response.dict()
            log_debug(f"response[{endpoint}]: {response_json}")
            return JSONResponse(content=response_json)

        return _handler

    def _rate_response(self, deployment_id: str, impl: ChatCompletion):
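        """Build a handler that records user feedback for a response."""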
        async def _handler(original_request: Request):
            set_log_deployment(deployment_id)
            request = await RateRequest.from_request(
                original_request, deployment_id
            )
            await impl.rate_response(request)
            return Response(status_code=200)

        return _handler

    def _chat_completion(
        self,
        deployment_id: str,
        impl: ChatCompletion,
        *,
        heartbeat_interval: Optional[float],
    ):
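        """Build a handler serving both streaming (SSE) and block (plain
        JSON) chat completion responses.
        """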
        async def _handler(original_request: Request):
            set_log_deployment(deployment_id)
            request = await ChatCompletionRequest.from_request(
                original_request, deployment_id
            )

            response = ChatCompletionResponse(request)
            stream = response._generate_stream(impl.chat_completion)

            if request.stream:
                if heartbeat_interval:
                    # ": heartbeat" is an SSE comment line, ignored by clients.
                    stream = add_heartbeat(
                        stream,
                        heartbeat_interval=heartbeat_interval,
                        heartbeat_callback=lambda: log_debug("heartbeat"),
                        heartbeat_object=": heartbeat\n\n",
                    )

                return StreamingResponse(
                    await to_streaming_response(stream),
                    media_type="text/event-stream",
                )
            else:
                response_json = await to_block_response(stream)
                log_debug(f"response: {response_json}")
                return JSONResponse(content=response_json)

        return _handler

    def _embeddings(self, deployment_id: str, impl: Embeddings):
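        """Build a handler that serves embeddings requests as plain JSON."""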
        async def _handler(original_request: Request):
            set_log_deployment(deployment_id)
            request = await EmbeddingsRequest.from_request(
                original_request, deployment_id
            )
            response = await impl.embeddings(request)
            response_json = response.dict()
            return JSONResponse(content=response_json)

        return _handler

    @staticmethod
    async def _healthcheck() -> JSONResponse:
        return JSONResponse(content={"status": "ok"})
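

# A minimal end-to-end sketch (assumes a hypothetical `main.py`; `EchoChat`
# is an illustrative ChatCompletion implementation, not part of the SDK):
#
#     from aidial_sdk import DIALApp
#
#     app = DIALApp(add_healthcheck=True)
#     app.add_chat_completion("echo", EchoChat(), heartbeat_interval=5.0)
#
# Serve with: uvicorn main:app --host 0.0.0.0 --port 5000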