aidial_sdk/utils/streaming.py (120 lines of code) (raw):
import asyncio
import json
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
Optional,
TypeVar,
Union,
cast,
)
from typing_extensions import assert_never
from aidial_sdk.chat_completion.chunks import BaseChunkWithDefaults
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.utils._cancel_scope import CancelScope
from aidial_sdk.utils.logging import log_debug
from aidial_sdk.utils.merge_chunks import cleanup_indices, merge
# Payload of the last "data: ..." line that to_streaming_response emits once
# the source stream has been fully consumed.
_DONE_MARKER = "[DONE]"
async def merge_chunks(chunk_stream: AsyncIterator[dict]) -> Dict[str, Any]:
    """Fold a stream of chunk dicts into a single response dict.

    Each chunk is combined into the accumulator with ``merge``. Once the
    stream is exhausted, every entry of ``response["choices"]`` gets its
    ``"delta"`` value passed through ``cleanup_indices`` and stored under
    ``"message"``; the ``"delta"`` key is then removed.

    Raises:
        KeyError: if the merged result has no ``"choices"`` key (e.g. the
            stream yielded nothing) or a choice has no ``"delta"`` key.
    """
    merged: Dict[str, Any] = {}
    async for part in chunk_stream:
        merged = merge(merged, part)

    # Convert the accumulated streaming "delta" into the final "message".
    for entry in merged["choices"]:
        delta = entry["delta"]
        entry["message"] = cleanup_indices(delta)
        del entry["delta"]

    return merged
def _format_chunk(data: Union[dict, str]) -> str:
    """Render *data* as a ``data: <payload>`` line followed by a blank line.

    A dict is serialized as compact JSON (no spaces after separators); a
    string is used verbatim. The rendered line (without the trailing
    newlines) is passed to ``log_debug`` before being returned.
    """
    if isinstance(data, dict):
        payload = json.dumps(data, separators=(",", ":"))
    else:
        payload = data

    line = "data: " + payload
    log_debug(line)
    return line + "\n\n"
# Items of a response stream: regular chunks, or a DIAL exception delivered
# in-band as a stream item.
ResponseStream = AsyncIterator[Union[BaseChunkWithDefaults, DIALException]]
# Same as ResponseStream, except that items may also be plain strings;
# to_streaming_response passes such strings through unchanged.
ResponseStreamWithStr = AsyncIterator[
Union[BaseChunkWithDefaults, DIALException, str]
]
async def _handle_exceptions_in_block_response(
    stream: ResponseStream,
) -> AsyncIterator[dict]:
    """Convert stream items to dicts, raising on an in-band exception.

    A ``DIALException`` item aborts the iteration: the result of its
    ``to_fastapi_exception()`` is raised. Any other item is yielded as
    ``to_dict(...)``, with defaults requested for the first item only.
    """
    # Defaults are applied to the first chunk only, which keeps the
    # follow-up merging logic simpler.
    with_defaults = True
    async for item in stream:
        if isinstance(item, DIALException):
            raise item.to_fastapi_exception()

        as_dict = item.to_dict(with_defaults=with_defaults)
        with_defaults = False
        yield as_dict
async def to_block_response(stream: ResponseStream) -> dict:
    """Consume *stream* completely and return one merged response dict.

    Items are first run through ``_handle_exceptions_in_block_response``
    (which raises on an in-band exception) and then folded together by
    ``merge_chunks``.
    """
    return await merge_chunks(_handle_exceptions_in_block_response(stream))
async def to_streaming_response(
    stream: ResponseStreamWithStr,
) -> AsyncIterator[str]:
    """Turn *stream* into an async iterator of ready-to-send strings.

    The first item is awaited eagerly, before the iterator is returned. If
    it is a ``DIALException``, the result of its ``to_fastapi_exception()``
    is raised instead of being rendered into the output. An empty stream is
    not handled specially: whatever ``stream.__anext__()`` raises propagates.

    In the returned iterator, exceptions are rendered via ``json_error()``,
    chunks via ``to_dict(with_defaults=True)`` (both through
    ``_format_chunk``), and strings are passed through as-is. After the
    source is exhausted a final ``_DONE_MARKER`` line is emitted.
    """
    head = await stream.__anext__()
    if isinstance(head, DIALException):
        raise head.to_fastapi_exception()

    def _render(
        item: Union[BaseChunkWithDefaults, DIALException, str]
    ) -> str:
        # Exceptions arriving after the first item are reported in-band.
        if isinstance(item, DIALException):
            return _format_chunk(item.json_error())
        if isinstance(item, str):
            return item
        if isinstance(item, BaseChunkWithDefaults):
            return _format_chunk(item.to_dict(with_defaults=True))
        assert_never(item)

    async def _events() -> AsyncIterator[str]:
        yield _render(head)
        async for item in stream:
            yield _render(item)
        yield _format_chunk(_DONE_MARKER)

    return _events()
# Generic item type of a stream.
_T = TypeVar("_T")
# A heartbeat object: either a ready value, or a zero-argument callable
# returning the value directly or as an awaitable.
_HeartbeatObject = Union[_T, Callable[[], Union[_T, Awaitable[_T]]]]
# A zero-argument heartbeat callback; it may return None or an awaitable.
_HeartbeatCallback = Callable[[], Union[None, Awaitable[None]]]
async def _eval_heartbeat_object(o: _HeartbeatObject[_T]) -> _T:
    """Resolve a heartbeat object to its value.

    A non-callable *o* is returned unchanged. A callable is invoked with no
    arguments; if its result is awaitable it is awaited, otherwise the
    result is returned directly.
    """
    if not callable(o):
        return o

    produced = o()
    if not isinstance(produced, Awaitable):
        return cast(_T, produced)
    return await produced
async def _call_heartbeat_callback(c: _HeartbeatCallback) -> None:
    """Invoke callback *c* and, if it returned an awaitable, await it."""
    outcome = c()
    if not isinstance(outcome, Awaitable):
        return
    await outcome
async def add_heartbeat(
    stream: AsyncIterator[_T],
    *,
    heartbeat_interval: float,
    heartbeat_object: Optional[_HeartbeatObject] = None,
    heartbeat_callback: Optional[_HeartbeatCallback] = None,
) -> AsyncIterator[_T]:
    """Re-yield items of *stream*, emitting heartbeats while it is idle.

    The next item is fetched in a task created through the ``CancelScope``.
    If the task does not complete within ``heartbeat_interval`` seconds, a
    heartbeat is emitted and the same pending task is awaited again (it is
    not restarted), so no item is lost.

    Args:
        stream: the source async iterator.
        heartbeat_interval: timeout, in seconds, passed to ``asyncio.wait``
            for each wait on the next item.
        heartbeat_object: if not None, it is resolved with
            ``_eval_heartbeat_object`` and the result is yielded on every
            heartbeat.
        heartbeat_callback: if not None, it is invoked via
            ``_call_heartbeat_callback`` on every heartbeat, after the
            heartbeat object (if any) has been yielded.

    Yields:
        Items of *stream*, interleaved with resolved heartbeat objects.

    Exceptions raised by the source other than ``StopAsyncIteration``
    propagate to the consumer.
    """
    # NOTE(review): CancelScope presumably cancels the tasks it created when
    # the block is left - confirm against its implementation.
    async with CancelScope() as cs:
        chunk_task: Optional[asyncio.Task[_T]] = None

        while True:
            if chunk_task is None:
                chunk_task = cs.create_task(stream.__anext__())

            done, _ = await asyncio.wait(
                [chunk_task],
                timeout=heartbeat_interval,
                return_when=asyncio.FIRST_COMPLETED,
            )

            if chunk_task in done:
                # The task is consumed; a new one is created on the next pass.
                finished_task, chunk_task = chunk_task, None

                # Only result() is guarded: StopAsyncIteration from the
                # source means it is exhausted. The yield is kept outside
                # the try so that an exception thrown into this generator at
                # the yield point is never mistaken for end-of-stream.
                try:
                    chunk = finished_task.result()
                except StopAsyncIteration:
                    break

                yield chunk
            else:
                # Timed out waiting for the next item: emit a heartbeat and
                # keep waiting on the same pending task.
                if heartbeat_object is not None:
                    yield await _eval_heartbeat_object(heartbeat_object)

                if heartbeat_callback is not None:
                    await _call_heartbeat_callback(heartbeat_callback)