aidial_adapter_openai/utils/streaming.py (222 lines of code) (raw):

import logging
from time import time
from typing import Any, AsyncIterator, Callable, Optional, TypeVar
from uuid import uuid4

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.utils.merge_chunks import merge_chat_completion_chunks
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from pydantic import BaseModel

from aidial_adapter_openai.utils.chat_completion_response import (
    ChatCompletionResponse,
    ChatCompletionStreamingChunk,
)
from aidial_adapter_openai.utils.log_config import logger
from aidial_adapter_openai.utils.sse_stream import to_openai_sse_stream


def generate_id() -> str:
    """Return a fresh completion id of the form ``chatcmpl-<uuid4>``."""
    return "chatcmpl-" + str(uuid4())


def generate_created() -> int:
    """Return the current Unix time truncated to whole seconds."""
    return int(time())


def build_chunk(
    id: str,
    finish_reason: Optional[str],
    message: Any,
    created: int,
    is_stream: bool,
    **extra: Any,
) -> dict:
    """Build a chat completion payload with exactly one choice (index 0).

    Args:
        id: Value for the top-level ``id`` field.
        finish_reason: Value for the choice's ``finish_reason``.
        message: The choice payload. Stored under ``delta`` when
            ``is_stream`` is true, otherwise under ``message``.
        created: Value for the top-level ``created`` field.
        is_stream: Selects a ``chat.completion.chunk`` (streaming) or a
            ``chat.completion`` (non-streaming) object.
        **extra: Additional top-level fields merged into the result last,
            so they may override the fields above.

    Returns:
        The payload as a plain dict.
    """
    message_key = "delta" if is_stream else "message"
    object_name = "chat.completion.chunk" if is_stream else "chat.completion"
    return {
        "id": id,
        "object": object_name,
        "created": created,
        "choices": [
            {
                "index": 0,
                message_key: message,
                "finish_reason": finish_reason,
            }
        ],
        **extra,
    }


async def generate_stream(
    *,
    stream: AsyncIterator[dict],
    get_prompt_tokens: Callable[[], int],
    tokenize_response: Callable[[ChatCompletionResponse], int],
    deployment: str,
    discarded_messages: Optional[list[int]],
    eliminate_empty_choices: bool,
) -> AsyncIterator[dict]:
    """Re-emit an upstream chunk stream, amending the final chunk.

    Chunks are forwarded with a one-chunk delay so that the last chunk can
    still be modified once the upstream stream has ended (or failed):

    * ``statistics.discarded_messages`` is attached when
      ``discarded_messages`` is given;
    * ``usage`` is computed locally (via ``tokenize_response`` and
      ``get_prompt_tokens``) when the upstream did not report it. On error
      this happens only if the snapshot already has messages;
    * on success, ``finish_reason`` is forced to ``"length"`` if no chunk
      carried a finish reason.

    If no chunk is available to amend, a synthetic empty streaming chunk
    (carrying ``model=deployment``) is used instead.

    Args:
        stream: Upstream chunks as dicts.
        get_prompt_tokens: Returns the prompt token count.
        tokenize_response: Returns the completion token count for the
            accumulated response.
        deployment: Model name stored in the synthetic fallback chunk.
        discarded_messages: Indices reported under
            ``statistics.discarded_messages``, or ``None`` to omit.
        eliminate_empty_choices: When true, chunks with an empty ``choices``
            list are withheld and merged into the next chunk.

    Raises:
        Any exception raised while iterating ``stream``. It is re-raised
        after the (amended) final chunk has been yielded.
    """
    empty_chunk = build_chunk(
        id=generate_id(),
        created=generate_created(),
        model=deployment,
        is_stream=True,
        message={},
        finish_reason=None,
    )

    def set_usage(chunk: dict | None, resp: ChatCompletionResponse) -> dict:
        """Attach locally computed ``usage`` to ``chunk`` (best effort)."""
        chunk = chunk or empty_chunk

        # Do not fail the whole response if tokenization has failed
        try:
            completion_tokens = tokenize_response(resp)
            prompt_tokens = get_prompt_tokens()
        except Exception as e:
            logger.exception(
                f"caught exception while tokenization: {type(e).__module__}.{type(e).__name__}. "
                "The tokenization has failed, therefore, the usage won't be reported."
            )
        else:
            chunk["usage"] = {
                "completion_tokens": completion_tokens,
                "prompt_tokens": prompt_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
        return chunk

    def set_finish_reason(chunk: dict | None, finish_reason: str) -> dict:
        """Set ``finish_reason`` on choice 0, creating the choice if absent."""
        chunk = chunk or empty_chunk
        chunk["choices"] = chunk.get("choices") or [{"index": 0, "delta": {}}]
        chunk["choices"][0]["finish_reason"] = finish_reason
        return chunk

    def set_discarded_messages(chunk: dict | None, indices: list[int]) -> dict:
        """Attach ``statistics.discarded_messages`` to ``chunk``."""
        chunk = chunk or empty_chunk
        chunk["statistics"] = {"discarded_messages": indices}
        return chunk

    last_chunk = None
    buffer_chunk = None
    response_snapshot = ChatCompletionStreamingChunk()

    error: Exception | None = None

    try:
        async for chunk in stream:
            response_snapshot.merge(chunk)

            if buffer_chunk is not None:
                chunk = merge_chat_completion_chunks(chunk, buffer_chunk)
                buffer_chunk = None

            choices = chunk.get("choices") or []

            # Azure OpenAI returns an empty list of choices as a first chunk
            # when content filtering is enabled for a corresponding deployment.
            # The safety rating of the request is reported in this first chunk.
            # Here we withhold such a chunk and merge it later with a follow-up chunk.
            if len(choices) == 0 and eliminate_empty_choices:
                buffer_chunk = chunk
            else:
                # Forward the previous chunk; keep the current one so it can
                # still be amended after the stream ends.
                if last_chunk is not None:
                    yield last_chunk
                last_chunk = chunk
    except Exception as e:
        logger.exception(
            f"caught exception while streaming: {type(e).__module__}.{type(e).__name__}"
        )
        error = e

    # NOTE(review): if only empty-choice chunks were received, last_chunk is
    # None and the withheld chunk is dropped here — confirm this is intended.
    if last_chunk is not None and buffer_chunk is not None:
        last_chunk = merge_chat_completion_chunks(last_chunk, buffer_chunk)

    if discarded_messages is not None:
        last_chunk = set_discarded_messages(last_chunk, discarded_messages)

    # Computed once: set_usage does not touch response_snapshot, so a second
    # call on the success path would only repeat the same tokenization.
    if response_snapshot.usage is None and (
        not error or response_snapshot.has_messages
    ):
        last_chunk = set_usage(last_chunk, response_snapshot)

    if not error:
        has_finish_reason = response_snapshot.has_finish_reason

        if response_snapshot.is_empty:
            logger.warning("Received 0 chunks")
        elif not has_finish_reason:
            logger.warning("Didn't receive chunk with the finish reason")

        if not has_finish_reason:
            last_chunk = set_finish_reason(last_chunk, "length")

    if last_chunk:
        yield last_chunk

    if error:
        raise error


def create_stage_chunk(name: str, content: str, stream: bool) -> dict:
    """Build a response holding a single completed stage and no text.

    The assistant message has empty ``content`` and a ``custom_content``
    with one stage (``index`` 0, status ``"completed"``). ``finish_reason``
    is ``"stop"`` and all usage counters are zero.
    """
    chunk_id = generate_id()
    created = generate_created()

    stage = {
        "index": 0,
        "name": name,
        "content": content,
        "status": "completed",
    }

    custom_content = {"stages": [stage]}

    return build_chunk(
        chunk_id,
        "stop",
        {
            "role": "assistant",
            "content": "",
            "custom_content": custom_content,
        },
        created,
        stream,
        usage={
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0,
        },
    )


def create_response_from_chunk(
    chunk: dict, exc: DialException | None, stream: bool
) -> AsyncIterator[dict] | Response:
    """Wrap a ready-made chunk (and an optional error) as a response.

    Non-streaming: returns ``exc`` as a FastAPI response if given, otherwise
    ``chunk`` as JSON. Streaming: returns an async iterator yielding
    ``chunk`` followed by ``exc.json_error()`` when ``exc`` is given.
    """
    if not stream:
        if exc is not None:
            return exc.to_fastapi_response()
        else:
            return JSONResponse(content=chunk)

    async def generator() -> AsyncIterator[dict]:
        yield chunk
        if exc is not None:
            yield exc.json_error()

    return generator()


def block_response_to_streaming_chunk(response: dict) -> dict:
    """Convert a non-streaming completion into a streaming chunk in place.

    Sets ``object`` to ``chat.completion.chunk`` and renames each choice's
    truthy ``message`` to ``delta``. Returns the same (mutated) dict.
    """
    response["object"] = "chat.completion.chunk"
    for choice in response.get("choices") or []:
        if message := choice.get("message"):
            choice["delta"] = message
            del choice["message"]
    return response


def create_server_response(
    emulate_stream: bool,
    response: AsyncIterator[dict] | dict | BaseModel | Response,
) -> Response:
    """Turn an adapter result into a FastAPI response.

    * async iterator -> SSE ``StreamingResponse``;
    * dict / pydantic model -> JSON, or a one-chunk SSE stream when
      ``emulate_stream`` is true;
    * an existing ``Response`` is returned unchanged.
    """

    def block_to_stream(block: dict) -> AsyncIterator[dict]:
        async def stream():
            yield block_response_to_streaming_chunk(block)

        return stream()

    def stream_to_response(stream: AsyncIterator[dict]) -> Response:
        return StreamingResponse(
            to_openai_sse_stream(stream),
            media_type="text/event-stream",
        )

    def block_to_response(block: dict) -> Response:
        if emulate_stream:
            return stream_to_response(block_to_stream(block))
        else:
            return JSONResponse(block)

    if isinstance(response, AsyncIterator):
        return stream_to_response(response)

    if isinstance(response, dict):
        return block_to_response(response)

    if isinstance(response, BaseModel):
        return block_to_response(response.dict())

    return response


T = TypeVar("T")
V = TypeVar("V")


async def prepend_to_stream(
    value: T, iterator: AsyncIterator[T]
) -> AsyncIterator[T]:
    """Yield ``value`` first, then every item of ``iterator``."""
    yield value
    async for item in iterator:
        yield item


async def map_stream(
    func: Callable[[T], Optional[V]], iterator: AsyncIterator[T]
) -> AsyncIterator[V]:
    """Apply ``func`` to each item, skipping results that are ``None``."""
    async for item in iterator:
        new_item = func(item)
        if new_item is not None:
            yield new_item


def debug_print(title: str, chunk: dict) -> None:
    """Log ``chunk`` at DEBUG level; formatting is skipped otherwise."""
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"{title}: {chunk}")


def chunk_to_dict(chunk: ChatCompletionChunk) -> dict:
    """Convert an OpenAI SDK chunk to a plain dict and debug-log it."""
    chunk_dict = chunk.to_dict()
    debug_print("chunk", chunk_dict)
    return chunk_dict