aidial_sdk/chat_completion/choice.py (177 lines of code) (raw):

import json
from types import TracebackType
from typing import Any, Optional, Type, overload

from aidial_sdk.chat_completion._types import ChunkQueue
from aidial_sdk.chat_completion.choice_base import ChoiceBase
from aidial_sdk.chat_completion.chunks import (
    AttachmentChunk,
    BaseChunk,
    ContentChunk,
    EndChoiceChunk,
    FormSchemaChunk,
    StartChoiceChunk,
    StateChunk,
)
from aidial_sdk.chat_completion.enums import FinishReason
from aidial_sdk.chat_completion.function_call import FunctionCall
from aidial_sdk.chat_completion.function_tool_call import FunctionToolCall
from aidial_sdk.chat_completion.request import Attachment
from aidial_sdk.chat_completion.stage import Stage
from aidial_sdk.pydantic_v1 import ValidationError
from aidial_sdk.utils._attachment import create_attachment
from aidial_sdk.utils._content_stream import ContentStream
from aidial_sdk.utils.errors import runtime_error
from aidial_sdk.utils.logging import log_debug


class Choice(ChoiceBase):
    """A single chat-completion choice that emits its output as chunks.

    Every piece of output (content, attachments, state, form schema and the
    start/end markers) is wrapped in a chunk object and pushed onto the
    chunk queue handed to the constructor.

    Lifecycle: a choice must be opened with :meth:`open` before anything is
    added to it and is finished with :meth:`close`. It can be used as a
    context manager, which opens it on entry and closes it on exit.
    """

    # Queue that receives every chunk produced by this choice.
    _queue: ChunkQueue
    # Index of this choice; attached to the chunks it produces.
    _index: int
    # Index assigned to the next attachment / stage / tool call.
    # (Despite the "last" prefix these hold the *next* free index.)
    _last_attachment_index: int
    _last_stage_index: int
    _last_tool_call_index: int
    # True once create_function_call() has been used on this choice.
    _has_function_call: bool
    # Lifecycle flags set by open() and close().
    _opened: bool
    _closed: bool
    # Guards ensuring state / form schema are each submitted at most once.
    _state_submitted: bool
    _schema_submitted: bool
    # Finish reason implied by the most recent output; close() falls back
    # to it when no explicit reason is given.
    _last_finish_reason: Optional[FinishReason]

    def __init__(self, queue: ChunkQueue, choice_index: int):
        """Create an unopened choice.

        Args:
            queue: Queue onto which the chunks of this choice are put.
            choice_index: Index of this choice.
        """
        self._queue = queue
        self._index = choice_index
        self._last_attachment_index = 0
        self._last_stage_index = 0
        self._last_tool_call_index = 0
        self._has_function_call = False
        self._opened = False
        self._closed = False
        self._state_submitted = False
        self._schema_submitted = False
        self._last_finish_reason = None

    def __enter__(self):
        """Open the choice and return it."""
        self.open()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Optional[bool]:
        """Close the choice on leaving the ``with`` block.

        Always returns ``False`` so an exception raised inside the block
        keeps propagating.
        """
        if exc is not None and self._closed:
            # The block already closed the choice and then failed. Calling
            # close() again would raise "already closed" and replace the
            # real error, so let the original exception through instead.
            return False

        self.close()
        return False

    def send_chunk(self, chunk: BaseChunk) -> None:
        """Log the chunk as JSON via ``log_debug`` and put it on the queue.

        The chunk is enqueued with ``put_nowait``.
        """
        log_debug("chunk: " + json.dumps(chunk.to_dict()))
        self._queue.put_nowait(chunk)

    @property
    def index(self) -> int:
        """Index of this choice."""
        return self._index

    @property
    def opened(self) -> bool:
        """Whether :meth:`open` has been called."""
        return self._opened

    @property
    def closed(self) -> bool:
        """Whether :meth:`close` has been called."""
        return self._closed

    @property
    def has_function_call(self) -> bool:
        """Whether :meth:`create_function_call` has been used."""
        return self._has_function_call

    def append_content(self, content: str) -> None:
        """Send a content chunk and set the pending finish reason to STOP.

        Raises:
            The error built by ``runtime_error`` if the choice is not
            opened or is already closed.
        """
        if not self._opened:
            raise runtime_error(
                "Trying to append content to an unopened choice"
            )
        if self._closed:
            raise runtime_error("Trying to append content to a closed choice")

        self.send_chunk(ContentChunk(content, self._index))
        self._last_finish_reason = FinishReason.STOP

    @property
    def content_stream(self) -> ContentStream:
        """A new ``ContentStream`` wrapping this choice (one per access)."""
        return ContentStream(self)

    def create_function_tool_call(
        self, id: str, name: str, arguments: Optional[str] = None
    ) -> FunctionToolCall:
        """Create and send a tool call using the next tool-call index.

        Afterwards the pending finish reason is TOOL_CALLS.
        """
        function_tool_call = FunctionToolCall.create_and_send(
            self, self._last_tool_call_index, id, name, arguments
        )
        self._last_tool_call_index += 1
        self._last_finish_reason = FinishReason.TOOL_CALLS
        return function_tool_call

    def create_function_call(
        self, name: str, arguments: Optional[str] = None
    ) -> FunctionCall:
        """Create and send a function call.

        Marks the choice as having a function call and sets the pending
        finish reason to FUNCTION_CALL.
        """
        function_call = FunctionCall.create_and_send(self, name, arguments)
        self._has_function_call = True
        self._last_finish_reason = FinishReason.FUNCTION_CALL
        return function_call

    @overload
    def add_attachment(self, attachment: Attachment) -> None: ...

    @overload
    def add_attachment(
        self,
        type: Optional[str] = None,
        title: Optional[str] = None,
        data: Optional[str] = None,
        url: Optional[str] = None,
        reference_url: Optional[str] = None,
        reference_type: Optional[str] = None,
    ) -> None: ...

    def add_attachment(self, *args, **kwargs) -> None:
        """Send an attachment chunk using the next attachment index.

        The arguments are forwarded to ``create_attachment``; see the
        overloads for the accepted call shapes.

        Raises:
            The error built by ``runtime_error`` if the choice is not
            opened, is already closed, or the attachment fails validation
            (the message of the first validation error is reported and the
            ``ValidationError`` is chained as the cause).
        """
        if not self._opened:
            raise runtime_error(
                "Trying to add attachment to an unopened choice"
            )
        if self._closed:
            raise runtime_error("Trying to add attachment to a closed choice")

        try:
            attachment_chunk = AttachmentChunk(
                choice_index=self._index,
                attachment_index=self._last_attachment_index,
                **create_attachment(*args, **kwargs).dict(),
            )
        except ValidationError as e:
            raise runtime_error(e.errors()[0]["msg"]) from e

        self.send_chunk(attachment_chunk)
        # The index advances only after the chunk was sent successfully.
        self._last_attachment_index += 1

    def set_state(self, state: Any) -> None:
        """Send the state chunk; allowed only once per choice.

        Raises:
            The error built by ``runtime_error`` if state was already set,
            or the choice is not opened or is already closed.
        """
        if self._state_submitted:
            raise runtime_error('Trying to set "state" twice')
        if not self._opened:
            raise runtime_error("Trying to append state to an unopened choice")
        if self._closed:
            raise runtime_error("Trying to append state to a closed choice")

        self._state_submitted = True
        self.send_chunk(StateChunk(self._index, state))

    def set_form_schema(self, form_schema: Any) -> None:
        """Send the form-schema chunk; allowed only once per choice.

        Raises:
            The error built by ``runtime_error`` if the schema was already
            set, or the choice is not opened or is already closed.
        """
        if self._schema_submitted:
            raise runtime_error("Trying to set form schema twice")
        if not self._opened:
            raise runtime_error(
                "Trying to append form schema to an unopened choice"
            )
        if self._closed:
            raise runtime_error(
                "Trying to append form schema to a closed choice"
            )

        self._schema_submitted = True
        self.send_chunk(FormSchemaChunk(self._index, form_schema))

    def create_stage(self, name: Optional[str] = None) -> Stage:
        """Create a ``Stage`` on this choice using the next stage index.

        The stage is constructed with the same chunk queue as the choice.

        Raises:
            The error built by ``runtime_error`` if the choice is not
            opened or is already closed.
        """
        if not self._opened:
            raise runtime_error("Trying to create stage to an unopened choice")
        if self._closed:
            raise runtime_error("Trying to create stage to a closed choice")

        stage = Stage(self._queue, self._index, self._last_stage_index, name)
        self._last_stage_index += 1
        return stage

    def open(self):
        """Mark the choice as opened and send the start-of-choice chunk.

        Raises:
            The error built by ``runtime_error`` if the choice is already
            open.
        """
        if self._opened:
            raise runtime_error("The choice is already open")

        self._opened = True
        self.send_chunk(StartChoiceChunk(choice_index=self._index))

    def close(self, finish_reason: Optional[FinishReason] = None) -> None:
        """Mark the choice as closed and send the end-of-choice chunk.

        Args:
            finish_reason: Reason to report. When omitted, the reason
                implied by the most recent output is used, and STOP if
                nothing has set one.

        Raises:
            The error built by ``runtime_error`` if the choice is not
            opened or is already closed.
        """
        if not self._opened:
            raise runtime_error("Trying to close an unopened choice")
        if self._closed:
            raise runtime_error(
                "Trying to close a choice which is already closed"
            )

        reason = finish_reason or self._last_finish_reason or FinishReason.STOP

        self._closed = True
        self.send_chunk(EndChoiceChunk(reason, self._index))