aidial_adapter_openai/utils/sse_stream.py (54 lines of code) (raw):

"""Helpers for converting between OpenAI-style SSE text and chunk objects."""

import json
from typing import Any, AsyncIterator, Mapping

from aidial_sdk.exceptions import runtime_server_error

from aidial_adapter_openai.exception_handlers import to_adapter_exception
from aidial_adapter_openai.utils.log_config import logger

DATA_PREFIX = "data: "
OPENAI_END_MARKER = "[DONE]"


def format_chunk(data: str | Mapping[str, Any]) -> str:
    """Render *data* as a single SSE ``data:`` event.

    A string is stripped of surrounding whitespace and emitted verbatim;
    a mapping is serialized as compact JSON (no spaces after separators).
    The result is always terminated by a blank line.
    """
    if isinstance(data, str):
        body = data.strip()
    else:
        body = json.dumps(data, separators=(",", ":"))
    return f"{DATA_PREFIX}{body}\n\n"


END_CHUNK = format_chunk(OPENAI_END_MARKER)


async def parse_openai_sse_stream(
    stream: AsyncIterator[bytes],
) -> AsyncIterator[dict]:
    """Decode an SSE byte stream into parsed JSON chunks.

    Each item of *stream* is treated as one event line. Blank lines are
    skipped, and iteration stops silently at the ``[DONE]`` marker. On the
    first malformed item (undecodable bytes, missing ``data: `` prefix, or
    invalid JSON) the corresponding ``runtime_server_error(...).json_error()``
    payload is yielded instead of a chunk and the generator finishes.
    """
    async for raw in stream:
        try:
            # "utf-8-sig" drops a leading BOM if the item carries one.
            text = raw.decode("utf-8-sig").lstrip()
        except Exception:
            yield runtime_server_error(
                "Can't decode chunk to a string"
            ).json_error()
            return

        if not text.strip():
            continue

        if not text.startswith(DATA_PREFIX):
            yield runtime_server_error("Invalid chunk format").json_error()
            return

        body = text.removeprefix(DATA_PREFIX)

        if body.strip() == OPENAI_END_MARKER:
            # End of stream: nothing is yielded for the marker itself.
            return

        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            yield runtime_server_error("Can't parse chunk to JSON").json_error()
            return
        else:
            yield parsed


async def to_openai_sse_stream(
    stream: AsyncIterator[dict],
) -> AsyncIterator[str]:
    """Serialize chunks from *stream* as SSE events, ending with ``[DONE]``.

    If iterating *stream* raises, the exception is logged, converted with
    ``to_adapter_exception`` and emitted as one final error event. The end
    marker is sent afterwards in either case.
    """
    try:
        async for item in stream:
            yield format_chunk(item)
    except Exception as exc:
        exc_type = type(exc)
        logger.exception(
            f"caught exception while streaming: {exc_type.__module__}.{exc_type.__name__}"
        )

        converted = to_adapter_exception(exc)
        logger.error(f"converted to the adapter exception: {converted!r}")

        yield format_chunk(converted.json_error())

    yield END_CHUNK