aidial_assistant/json_stream/json_array.py (79 lines of code) (raw):

from collections.abc import AsyncIterator
from typing import Any

from typing_extensions import override

from aidial_assistant.json_stream.chunked_char_stream import (
    ChunkedCharStream,
    skip_whitespaces,
)
from aidial_assistant.json_stream.exceptions import (
    unexpected_end_of_stream_error,
    unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import (
    CompoundNode,
    JsonNode,
    NodeParser,
)


class JsonArray(CompoundNode[list[Any], JsonNode]):
    """JSON array node whose elements are parsed lazily from a char stream.

    The element source handed to `CompoundNode` is the async generator
    returned by `read`.
    """

    def __init__(self, source: AsyncIterator[JsonNode], pos: int):
        super().__init__(source, pos)
        # Element nodes received through `_accumulate`, in arrival order.
        self._array: list[JsonNode] = []

    @override
    def type(self) -> str:
        return "array"

    @staticmethod
    async def read(
        stream: ChunkedCharStream, node_parser: NodeParser
    ) -> AsyncIterator[JsonNode]:
        """Parse a JSON array from `stream`, yielding each element node.

        The array grammar is enforced: elements must be separated by exactly
        one comma; leading, doubled, trailing and missing commas are rejected.

        Args:
            stream: Character stream positioned at the opening "[" (leading
                whitespace is skipped).
            node_parser: Parser used to read each element.

        Yields:
            Each element node as returned by `node_parser.parse`; a compound
            element may not have been read to its end yet when yielded.

        Raises:
            The error built by `unexpected_symbol_error` when the array is
            malformed, and the error built by `unexpected_end_of_stream_error`
            when the stream ends before the closing "]".
        """
        try:
            await skip_whitespaces(stream)
            char = await anext(stream)
            if not JsonArray.starts_with(char):
                raise unexpected_symbol_error(char, stream.char_position)

            # True right after an element: only "," or "]" may follow.
            is_comma_expected = False
            # True right after a ",": another element must follow, not "]".
            is_value_expected = False
            while True:
                await skip_whitespaces(stream)
                char = await stream.apeek()
                if char == "]":
                    if is_value_expected:
                        # Trailing comma, e.g. "[1, ]".
                        raise unexpected_symbol_error(
                            char, stream.char_position
                        )
                    await stream.askip()
                    break

                if char == ",":
                    if not is_comma_expected:
                        # Leading or doubled comma, e.g. "[, 1]", "[1,, 2]".
                        raise unexpected_symbol_error(
                            char, stream.char_position
                        )
                    await stream.askip()
                    is_comma_expected = False
                    is_value_expected = True
                else:
                    if is_comma_expected:
                        # Missing separator between elements, e.g. "[1 2]".
                        raise unexpected_symbol_error(
                            char, stream.char_position
                        )
                    value = await node_parser.parse(stream)
                    yield value
                    # The next loop iteration looks for "," or "]", so a
                    # nested compound element must be read to its end first
                    # (read_to_end is assumed to consume its remaining input).
                    if isinstance(value, CompoundNode):
                        await value.read_to_end()
                    is_comma_expected = True
                    is_value_expected = False
        except StopAsyncIteration:
            # Running out of input (e.g. `anext` on an exhausted stream)
            # before "]" is reported as an end-of-stream parse error.
            raise unexpected_end_of_stream_error(stream.char_position)

    @override
    async def to_chunks(self) -> AsyncIterator[str]:
        """Serialize the array as text chunks, separating elements by ", "."""
        yield "["
        is_first_element = True
        async for value in self:
            if not is_first_element:
                yield ", "
            async for chunk in value.to_chunks():
                yield chunk
            is_first_element = False
        yield "]"

    @override
    def value(self) -> list[Any]:
        """Return the plain values of the element nodes accumulated so far."""
        return [item.value() for item in self._array]

    @override
    def _accumulate(self, element: JsonNode):
        self._array.append(element)

    @classmethod
    def parse(
        cls, stream: ChunkedCharStream, node_parser: NodeParser
    ) -> "JsonArray":
        """Create an array node that reads its elements lazily from `stream`.

        The node's position is the stream position at the time of this call.
        """
        return cls(cls.read(stream, node_parser), stream.char_position)

    @staticmethod
    def starts_with(char: str) -> bool:
        """Return True if `char` opens a JSON array."""
        return char == "["