aidial_adapter_dial/utils/sse_stream.py (30 lines of code) (raw):

import json
import logging
import re
from typing import Any, AsyncIterator, Mapping

from aidial_adapter_dial.utils.exceptions import (
    to_dial_exception,
    to_json_content,
)

DATA_PREFIX = "data: "
OPENAI_END_MARKER = "[DONE]"

# The SSE format treats exactly these as line terminators: CRLF, lone CR,
# lone LF. (str.splitlines would also split on e.g. U+2028, altering data.)
_SSE_LINE_BREAK = re.compile(r"\r\n|\r|\n")


def format_chunk(data: str | Mapping[str, Any]) -> str:
    """Format *data* as a single SSE event terminated by a blank line.

    A string payload is stripped of surrounding whitespace and sent as-is.
    Any other mapping is serialized as compact JSON (no spaces after
    separators).

    Every line of the payload gets its own ``data: `` prefix, as the SSE
    format requires. This lets a multi-line string reach the client intact
    instead of losing its continuation lines. Single-line payloads, which
    include all JSON produced here, are formatted exactly as
    ``"data: <payload>\\n\\n"``.
    """
    if isinstance(data, str):
        payload = data.strip()
    else:
        # json.dumps only accepts real dicts at the top level. Copying lets
        # any Mapping (e.g. MappingProxyType) be serialized, as the type
        # annotation promises.
        payload = json.dumps(dict(data), separators=(",", ":"))

    lines = _SSE_LINE_BREAK.split(payload)
    # The trailing "\n" after the per-line "\n"s yields the blank line that
    # terminates the event.
    return "".join(DATA_PREFIX + line + "\n" for line in lines) + "\n"


END_CHUNK = format_chunk(OPENAI_END_MARKER)

log = logging.getLogger(__name__)


async def to_openai_sse_stream(
    stream: AsyncIterator[dict],
) -> AsyncIterator[str]:
    """Convert an async stream of chunk dicts into SSE-formatted strings.

    Each chunk from *stream* is passed through :func:`format_chunk` and
    yielded.

    If iterating *stream* or formatting a chunk raises an ``Exception``:
    - the exception is logged with its traceback;
    - it is converted via ``to_dial_exception`` and ``to_json_content``;
    - the result is yielded as one more formatted chunk.

    The return value of ``to_json_content`` is presumably a JSON-serializable
    mapping or a string; TODO confirm.

    After normal completion, or after the error chunk, ``END_CHUNK``
    (``data: [DONE]``) is yielded last.
    """
    try:
        async for chunk in stream:
            yield format_chunk(chunk)
    except Exception as e:
        # Lazy %-formatting: the message is only built if the record is
        # actually emitted. The text matches the previous f-string output.
        log.exception(
            "caught exception while streaming: %s.%s",
            type(e).__module__,
            type(e).__name__,
        )
        dial_exception = to_dial_exception(e)
        error_chunk = to_json_content(dial_exception)
        yield format_chunk(error_chunk)

    yield END_CHUNK