aidial_sdk/chat_completion/response.py:

import asyncio
from time import time
from typing import Any, Callable, Coroutine, List
from uuid import uuid4

from typing_extensions import assert_never

from aidial_sdk.chat_completion._types import ChunkQueue
from aidial_sdk.chat_completion.choice import Choice
from aidial_sdk.chat_completion.chunks import (
    ArbitraryChunk,
    BaseChunk,
    BaseChunkWithDefaults,
    DefaultChunk,
    DiscardedMessagesChunk,
    EndChoiceChunk,
    EndChunk,
    ExceptionChunk,
    UsageChunk,
    UsagePerModelChunk,
)
from aidial_sdk.chat_completion.request import Request
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.exceptions import RequestValidationError, RuntimeServerError
from aidial_sdk.utils._cancel_scope import CancelScope
from aidial_sdk.utils.errors import RUNTIME_ERROR_MESSAGE, runtime_error
from aidial_sdk.utils.logging import log_error, log_exception
from aidial_sdk.utils.merge_chunks import merge
from aidial_sdk.utils.streaming import ResponseStream

_Producer = Callable[[Request, "Response"], Coroutine[Any, Any, Any]]


class Response:
    request: Request

    _queue: ChunkQueue
    _last_choice_index: int
    _last_usage_per_model_index: int
    _generation_started: bool
    _discarded_messages_generated: bool
    _usage_generated: bool
    _default_chunk: DefaultChunk

    def __init__(self, request: Request):
        self._queue = asyncio.Queue()
        self._last_choice_index = 0
        self._last_usage_per_model_index = 0
        self._generation_started = False
        self._discarded_messages_generated = False
        self._usage_generated = False

        self.request = request

        self._default_chunk = DefaultChunk(
            id=str(uuid4()),
            created=int(time()),
            object=(
                "chat.completion.chunk"
                if self.request.stream
                else "chat.completion"
            ),
        )

    @property
    def n(self) -> int:
        return self.request.n or 1

    @property
    def stream(self) -> bool:
        return self.request.stream

    async def _run_producer(self, producer: _Producer):
        # Runs the user-supplied producer; any non-DIAL exception is converted
        # into an internal server error. The queue is always terminated with
        # either an ExceptionChunk or an EndChunk sentinel.
        try:
            await producer(self.request, self)
        except Exception as e:
            if isinstance(e, DIALException):
                dial_exception = e
            else:
                log_exception(RUNTIME_ERROR_MESSAGE)
                dial_exception = RuntimeServerError(RUNTIME_ERROR_MESSAGE)
            self._queue.put_nowait(ExceptionChunk(dial_exception))
        else:
            self._queue.put_nowait(EndChunk())

    async def _generate_stream(self, producer: _Producer) -> ResponseStream:
        # Run the producer concurrently; CancelScope cancels it
        # if the stream is closed early.
        async with CancelScope() as cs:
            cs.create_task(self._run_producer(producer))
            async for chunk in self._generate_chunk_stream():
                yield chunk

    async def _generate_chunk_stream(self) -> ResponseStream:
        def _create_chunk(chunk: BaseChunk):
            return BaseChunkWithDefaults(
                chunk=chunk, defaults=self._default_chunk
            )

        # A list of chunks whose emitting is delayed up until the very last moment
        delayed_chunks: List[BaseChunk] = []

        while True:
            chunk = await self._queue.get()
            self._queue.task_done()

            if isinstance(chunk, BaseChunk):
                is_last_end_choice_chunk = (
                    isinstance(chunk, EndChoiceChunk)
                    and chunk.choice_index == self.n - 1
                )

                is_top_level_chunk = isinstance(
                    chunk,
                    (UsageChunk, UsagePerModelChunk, DiscardedMessagesChunk),
                )

                if is_last_end_choice_chunk or is_top_level_chunk:
                    delayed_chunks.append(chunk)
                else:
                    yield _create_chunk(chunk)
            elif isinstance(chunk, (ExceptionChunk, EndChunk)):
                # Flush the delayed chunks as a single merged chunk
                # before closing the stream.
                if delayed_chunks:
                    final_chunk = merge(*[d.to_dict() for d in delayed_chunks])
                    yield _create_chunk(ArbitraryChunk(chunk=final_chunk))

                if isinstance(chunk, ExceptionChunk):
                    yield chunk.exc
                elif isinstance(chunk, EndChunk):
                    if self._last_choice_index != self.n:
                        log_error("Not all choices were generated")
                        yield RuntimeServerError(RUNTIME_ERROR_MESSAGE)
                else:
                    assert_never(chunk)

                return
            else:
                assert_never(chunk)

    def create_choice(self) -> Choice:
        self._generation_started = True

        if self._last_choice_index >= self.n:
            raise runtime_error(
                "Trying to generate more choices than requested"
            )

        choice = Choice(self._queue, self._last_choice_index)
        self._last_choice_index += 1

        return choice

    def create_single_choice(self) -> Choice:
        if self._last_choice_index > 0:
            raise runtime_error(
                "Trying to generate a single choice "
                "after another choice has been created"
            )
        if self.n > 1:
            raise RequestValidationError(
                message=f"{self.request.deployment_id} deployment doesn't support n > 1"
            )

        return self.create_choice()

    def add_usage_per_model(
        self, model: str, prompt_tokens: int = 0, completion_tokens: int = 0
    ):
        self._generation_started = True

        if self._last_choice_index != self.n:
            raise runtime_error(
                'Trying to set "usage_per_model" before generating all choices',
            )

        self._queue.put_nowait(
            UsagePerModelChunk(
                self._last_usage_per_model_index,
                model,
                prompt_tokens,
                completion_tokens,
            )
        )
        self._last_usage_per_model_index += 1

    def set_discarded_messages(self, discarded_messages: List[int]):
        self._generation_started = True

        if self._discarded_messages_generated:
            raise runtime_error('Trying to set "discarded_messages" twice')
        if self._last_choice_index != self.n:
            raise runtime_error(
                'Trying to set "discarded_messages" before generating all choices',
            )

        self._discarded_messages_generated = True
        self._queue.put_nowait(DiscardedMessagesChunk(discarded_messages))

    def set_usage(self, prompt_tokens: int = 0, completion_tokens: int = 0):
        self._generation_started = True

        if self._usage_generated:
            raise runtime_error('Trying to set "usage" twice')
        if self._last_choice_index != self.n:
            raise runtime_error(
                'Trying to set "usage" before generating all choices',
            )

        self._usage_generated = True
        self._queue.put_nowait(UsageChunk(prompt_tokens, completion_tokens))

    async def aflush(self):
        await self._queue.join()

    def set_created(self, created: int):
        if self._generation_started:
            raise runtime_error(
                'Trying to set "created" after start of generation'
            )
        self._default_chunk["created"] = created

    def set_model(self, model: str):
        if self._generation_started:
            raise runtime_error(
                'Trying to set "model" after start of generation'
            )
        self._default_chunk["model"] = model

    def set_response_id(self, response_id: str):
        if self._generation_started:
            raise runtime_error(
                'Trying to set "response_id" after start of generation',
            )
        self._default_chunk["id"] = response_id
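
# ---------------------------------------------------------------------------
# Usage sketch (not part of this module). A minimal sketch of how Response is
# driven from a chat completion handler, assuming the SDK's public DIALApp and
# ChatCompletion entry points; the Echo class and the "echo" deployment name
# are hypothetical, chosen for illustration only. Of the calls below, only
# create_single_choice() and set_usage() are defined in this file.
#
# from aidial_sdk import DIALApp
# from aidial_sdk.chat_completion import ChatCompletion, Request, Response
#
#
# class Echo(ChatCompletion):
#     async def chat_completion(
#         self, request: Request, response: Response
#     ) -> None:
#         # create_single_choice() rejects requests with n > 1 (see above)
#         with response.create_single_choice() as choice:
#             choice.append_content(request.messages[-1].content or "")
#
#         # Top-level fields may only be set once all choices are generated
#         response.set_usage(prompt_tokens=0, completion_tokens=0)
#
#
# app = DIALApp()
# app.add_chat_completion("echo", Echo())
# ---------------------------------------------------------------------------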