aidial_assistant/json_stream/json_node.py (73 lines of code) (raw):

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Generic, TypeVar

from typing_extensions import override

from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import (
    unexpected_end_of_stream_error,
)


class NodeParser(ABC):
    """Interface for objects that parse one JSON node from a character stream."""

    @abstractmethod
    async def parse(self, stream: ChunkedCharStream) -> "JsonNode":
        """Read a single JSON node from *stream* and return it."""


TValue = TypeVar("TValue")
TElement = TypeVar("TElement")


class JsonNode(ABC, Generic[TValue]):
    """Base class for a node of an incrementally parsed JSON document.

    ``TValue`` is the Python type returned by :meth:`value`.
    """

    def __init__(self, pos: int):
        # Character position supplied by whoever creates the node; for
        # AtomicNode it is the stream position where the token starts.
        self._pos = pos

    @abstractmethod
    def type(self) -> str:
        """Return a string naming the kind of this node."""

    @abstractmethod
    def to_chunks(self) -> AsyncIterator[str]:
        """Return an async iterator over string chunks representing this node.

        Subclasses typically implement this as an async generator.
        """

    @property
    def pos(self) -> int:
        """Character position this node was created with."""
        return self._pos

    @abstractmethod
    def value(self) -> TValue:
        """Return the Python value represented by this node."""


class CompoundNode(
    JsonNode[TValue], AsyncIterator[TElement], ABC, Generic[TValue, TElement]
):
    """A JSON node whose contents arrive as an async stream of elements.

    Iterating the node pulls elements from ``source``.  Each element is passed
    to :meth:`_accumulate` before it is handed to the caller, so the node sees
    every element that has been consumed.
    """

    def __init__(self, source: AsyncIterator[TElement], pos: int):
        super().__init__(pos)
        self._source = source

    @override
    def __aiter__(self) -> AsyncIterator[TElement]:
        return self

    @override
    async def __anext__(self) -> TElement:
        # StopAsyncIteration raised by the source propagates unchanged and
        # ends iteration of this node.
        result = await anext(self._source)
        self._accumulate(result)
        return result

    @abstractmethod
    def _accumulate(self, element: TElement):
        """Record *element*; called once for every element pulled from the source."""

    async def read_to_end(self):
        """Consume all remaining elements so that each one is accumulated."""
        async for _ in self:
            pass


class AtomicNode(JsonNode[TValue], ABC, Generic[TValue]):
    """A JSON node consisting of a single bare token read in one piece.

    The token is terminated by whitespace or a structural character
    (``,:[]{}``).  Its raw text is kept verbatim and emitted as one chunk.
    """

    def __init__(self, raw_data: str, pos: int):
        super().__init__(pos)
        self._raw_data = raw_data

    @override
    async def to_chunks(self) -> AsyncIterator[str]:
        yield self._raw_data

    @classmethod
    async def parse(cls, stream: ChunkedCharStream) -> "AtomicNode":
        """Read a bare token from *stream* and wrap it in an instance of ``cls``.

        The node's position is the stream position captured before reading.
        """
        position = stream.char_position
        return cls(await AtomicNode._read_all(stream), position)

    @staticmethod
    async def _read_all(stream: ChunkedCharStream) -> str:
        """Read characters up to, but not including, the next delimiter.

        Delimiters are whitespace and any of ``,:[]{}``.  The delimiter is
        only peeked, never skipped, so it stays in the stream for the caller.

        Raises:
            The exception built by ``unexpected_end_of_stream_error`` if the
            stream ends (``StopAsyncIteration``) before a delimiter is seen.
        """
        # NOTE(review): EOF immediately after a bare token is always reported
        # as an error here, even when the token is the whole document
        # (e.g. ``42``) - confirm callers never parse such top-level values.
        chars: list[str] = []
        try:
            while True:
                char = await stream.apeek()
                if char.isspace() or char in ",:[]{}":
                    # Join once instead of growing a string char by char.
                    return "".join(chars)
                chars.append(char)
                await stream.askip()
        except StopAsyncIteration:
            raise unexpected_end_of_stream_error(stream.char_position)