aidial_sdk/header_propagator.py (104 lines of code) (raw):
import types
from contextvars import ContextVar
from typing import MutableMapping, Optional
import wrapt
from fastapi import FastAPI
from starlette.types import ASGIApp, Receive, Scope, Send
class FastAPIMiddleware:
    """ASGI middleware that captures the incoming ``api-key`` header.

    The value of the request's ``api-key`` header is stored in the supplied
    ``ContextVar`` while the downstream application handles the request, so
    code running during that call can read it.
    """

    def __init__(
        self,
        app: ASGIApp,
        api_key: ContextVar[Optional[str]],
    ) -> None:
        """Wrap *app*.

        Args:
            app: The downstream ASGI application.
            api_key: Context variable that receives the header value.
        """
        self.app = app
        self.api_key = api_key

    async def __call__(
        self, scope: Scope, receive: Receive, send: Send
    ) -> None:
        """Handle one ASGI call, exposing its ``api-key`` header.

        If the header is present (the last occurrence wins), the context
        variable is set before the wrapped app is awaited and restored to its
        previous value afterwards, even if the app raises. Without the
        restore, the key would stay visible in the caller's context once the
        request finishes (e.g. if the caller awaits the app directly from its
        own task) and could be attached to later, unrelated requests.

        Scopes without a ``headers`` entry are passed straight through.
        """
        api_key: Optional[str] = None
        # Per the ASGI spec, header names arrive as lowercased byte strings,
        # so an exact byte comparison is enough.
        for header in scope.get("headers") or []:
            if header[0] == b"api-key":
                api_key = header[1].decode("utf-8")

        if api_key is None:
            await self.app(scope, receive, send)
            return

        token = self.api_key.set(api_key)
        try:
            await self.app(scope, receive, send)
        finally:
            # Scope the key to this request only.
            self.api_key.reset(token)
class HeaderPropagator:
    """Forward the incoming ``api-key`` header to outgoing DIAL requests.

    ``enable()`` installs :class:`FastAPIMiddleware` on the app to capture the
    ``api-key`` header of each incoming request. It also patches whichever of
    the HTTP client libraries ``aiohttp``, ``httpx`` and ``requests`` can be
    imported. After patching, requests sent to ``dial_url`` carry that key
    instead of the key the caller configured.

    The client patches are applied to the library classes themselves via
    ``wrapt.wrap_function_wrapper``, so they affect every client in the
    process.
    """

    _app: FastAPI
    _dial_url: str
    _api_key: ContextVar[Optional[str]]
    _enabled: bool

    def __init__(self, app: FastAPI, dial_url: str):
        """Create a propagator; nothing is patched until :meth:`enable`.

        Args:
            app: The FastAPI application receiving the original requests.
            dial_url: Base URL of DIAL. Only outgoing requests whose URL lies
                under this base have their headers rewritten.
        """
        self._app = app
        self._dial_url = dial_url
        self._api_key: ContextVar[Optional[str]] = ContextVar(
            "api_key", default=None
        )
        self._enabled = False

    def enable(self):
        """Install the middleware and client patches (once per instance)."""
        if self._enabled:
            return

        self._instrument_fast_api(self._app)
        self._instrument_aiohttp()
        self._instrument_httpx()
        self._instrument_requests()

        self._enabled = True

    def _instrument_fast_api(self, app: FastAPI):
        """Register the middleware that fills ``self._api_key`` per request."""
        app.add_middleware(FastAPIMiddleware, api_key=self._api_key)

    def _instrument_aiohttp(self):
        """Attach a request-start trace hook to new ``aiohttp`` sessions.

        The hook is injected from ``ClientSession.__init__``, so only sessions
        created after this call are affected. No-op if aiohttp is missing.
        """
        try:
            import aiohttp
        except ImportError:
            return

        async def _on_request_start(
            session: aiohttp.ClientSession,
            trace_config_ctx: types.SimpleNamespace,
            params: aiohttp.TraceRequestStartParams,
        ):
            # NOTE(review): relies on params.headers being the mapping that
            # aiohttp then sends, so in-place edits reach the wire.
            self._modify_headers(str(params.url), params.headers)

        def instrumented_init(wrapped, instance, args, kwargs):
            trace_config = aiohttp.TraceConfig()
            trace_config.on_request_start.append(_on_request_start)

            # Keep any trace configs the caller supplied.
            trace_configs = list(kwargs.get("trace_configs") or [])
            trace_configs.append(trace_config)
            kwargs["trace_configs"] = trace_configs

            return wrapped(*args, **kwargs)

        wrapt.wrap_function_wrapper(
            aiohttp.ClientSession, "__init__", instrumented_init
        )

    def _instrument_requests(self):
        """Rewrite headers of requests sent through ``requests.Session.send``.

        No-op if requests is missing.
        """
        try:
            import requests
        except ImportError:
            return

        def instrumented_send(wrapped, instance, args, kwargs):
            # Session.send(request, **kwargs): the prepared request may be
            # passed positionally or as the ``request`` keyword. If it is
            # absent, let the wrapped call raise its own error.
            request: Optional[requests.PreparedRequest] = (
                args[0] if args else kwargs.get("request")
            )
            if request is not None:
                self._modify_headers(request.url or "", request.headers)
            return wrapped(*args, **kwargs)

        wrapt.wrap_function_wrapper(requests.Session, "send", instrumented_send)

    def _instrument_httpx(self):
        """Rewrite headers of requests built by httpx sync and async clients.

        Only requests created via ``build_request`` are touched; requests
        constructed some other way are presumably not covered. No-op if httpx
        is missing.
        """
        try:
            import httpx
        except ImportError:
            return

        def instrumented_build_request(wrapped, instance, args, kwargs):
            request: httpx.Request = wrapped(*args, **kwargs)
            self._modify_headers(str(request.url), request.headers)
            return request

        wrapt.wrap_function_wrapper(
            httpx.Client, "build_request", instrumented_build_request
        )
        wrapt.wrap_function_wrapper(
            httpx.AsyncClient, "build_request", instrumented_build_request
        )

    def _is_dial_url(self, url: str) -> bool:
        """Return True if *url* lies under the configured DIAL base URL.

        The prefix must end at a URL boundary. A bare ``startswith`` would
        also accept look-alike URLs, and would leak the user's key to them:
        ``http://dial.example.com.attacker.net/...`` for base
        ``http://dial.example.com``, or ``http://dial:80801/...`` for base
        ``http://dial:8080``.
        """
        if not url.startswith(self._dial_url):
            return False
        if self._dial_url.endswith("/"):
            # The base already ends on a boundary.
            return True
        remainder = url[len(self._dial_url):]
        return not remainder or remainder[0] in "/?#"

    def _modify_headers(
        self, url: str, headers: MutableMapping[str, str]
    ) -> None:
        """Inject the current request's API key into outgoing DIAL headers.

        Applies only when *url* is a DIAL URL and the context holds a
        non-empty key. ``api-key`` is always overwritten. ``Authorization``
        is overwritten only when it equals ``Bearer <old api-key>``, i.e.
        when it was derived from the key being replaced; any other
        ``Authorization`` value is left untouched.

        Args:
            url: Full URL of the outgoing request.
            headers: The outgoing request's header mapping, modified in place.
        """
        if not self._is_dial_url(url):
            return

        api_key = self._api_key.get()
        if not api_key:
            return

        old_api_key = headers.get("api-key")
        old_authz = headers.get("Authorization")
        if old_api_key and old_authz == f"Bearer {old_api_key}":
            headers["Authorization"] = f"Bearer {api_key}"

        headers["api-key"] = api_key