aidial_interceptors_sdk/utils/streaming.py (59 lines of code) (raw):

import logging from typing import Any, AsyncIterator, Callable, Optional, TypeVar from aidial_sdk.exceptions import HTTPException as DialException from aidial_interceptors_sdk.chat_completion.annotated_value import ( AnnotatedChunk, AnnotatedException, AnnotatedValue, Annotation, ) from aidial_interceptors_sdk.utils._exceptions import to_dial_exception _log = logging.getLogger(__name__) _T = TypeVar("_T") _V = TypeVar("_V") async def materialize_streaming_errors( stream: AsyncIterator[dict], ) -> AsyncIterator[dict | DialException]: try: async for chunk in stream: yield chunk except Exception as e: _log.exception( f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}" ) yield to_dial_exception(e) def annotate_stream( annotation: Annotation, stream: AsyncIterator[dict | DialException] ) -> AsyncIterator[AnnotatedValue]: def _annotate(value: dict | DialException) -> AnnotatedValue: if isinstance(value, dict): return AnnotatedChunk(chunk=value, annotation=annotation) else: return AnnotatedException(error=value, annotation=annotation) return map_stream(_annotate, stream) # TODO: add to SDK as a inverse of cleanup_indices def _add_indices(chunk: Any) -> Any: if isinstance(chunk, list): ret = [] for idx, elem in enumerate(chunk, start=1): if isinstance(elem, dict) and "index" not in elem: elem = {**elem, "index": idx} ret.append(_add_indices(elem)) return ret if isinstance(chunk, dict): return {key: _add_indices(value) for key, value in chunk.items()} return chunk # TODO: add to SDK as an inverse of merge_chunks def block_response_to_streaming_chunk(response: dict) -> dict: for choice in response["choices"]: choice["delta"] = choice["message"] del choice["message"] _add_indices(choice["delta"]) return response async def map_stream( func: Callable[[_T], Optional[_V]], iterator: AsyncIterator[_T] ) -> AsyncIterator[_V]: async for item in iterator: new_item = func(item) if new_item is not None: yield new_item async def singleton_stream(item: _T) -> AsyncIterator[_T]: yield item