aidial_assistant/json_stream/json_object.py (104 lines of code) (raw):
import json
from collections.abc import AsyncIterator
from typing import Any, Tuple
from typing_extensions import override
from aidial_assistant.json_stream.chunked_char_stream import (
ChunkedCharStream,
skip_whitespaces,
)
from aidial_assistant.json_stream.exceptions import (
unexpected_end_of_stream_error,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import (
CompoundNode,
JsonNode,
NodeParser,
)
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.utils.text import join_string
class JsonObject(CompoundNode[dict[str, Any], Tuple[str, JsonNode]]):
    """JSON object node whose entries arrive as ``(key, node)`` pairs.

    Every pair handed to ``_accumulate`` is stored in ``_object``; ``get`` and
    ``value`` are answered from that cache.
    NOTE(review): ``_accumulate`` is presumably invoked by ``CompoundNode``
    while the node is being iterated - confirm against the base class.
    """

    def __init__(self, source: AsyncIterator[Tuple[str, JsonNode]], pos: int):
        """Create the node.

        Args:
            source: Async iterator producing the object's ``(key, node)``
                pairs; forwarded to the base class.
            pos: Position forwarded to the base class (``parse`` passes the
                stream's ``char_position``).
        """
        super().__init__(source, pos)
        # Entries seen so far, keyed by property name.
        self._object: dict[str, JsonNode] = {}

    @override
    def type(self) -> str:
        """Return the node type name, ``"object"``."""
        return "object"

    async def get(self, key: str) -> JsonNode:
        """Return the node stored under ``key``.

        Already accumulated entries are checked first; otherwise the node
        itself is iterated until a pair with a matching key shows up.

        Raises:
            KeyError: If iteration finishes without producing ``key``.
        """
        # A dict supports membership tests directly; no ``.keys()`` view needed.
        if key in self._object:
            return self._object[key]

        async for k, v in self:
            if k == key:
                return v

        raise KeyError(key)

    @staticmethod
    async def read(
        stream: ChunkedCharStream, node_parser: NodeParser
    ) -> AsyncIterator[Tuple[str, JsonNode]]:
        """Parse an object from ``stream``, yielding ``(key, node)`` pairs.

        Leading whitespace is skipped and the next character must open an
        object. Then, until ``}`` is consumed, each step accepts either a comma
        (only right after an entry) or a string-started ``key: value`` entry
        (only when no comma is pending). As written, a comma immediately
        followed by ``}`` is accepted.

        Args:
            stream: Character stream to read from.
            node_parser: Parser used to build each entry's value node.

        Raises:
            The error built by ``unexpected_symbol_error`` for a character
            that is not allowed at the current position, and the error built
            by ``unexpected_end_of_stream_error`` when a ``StopAsyncIteration``
            is raised inside the parsing loop.
        """
        try:
            await skip_whitespaces(stream)
            char = await anext(stream)
            if not JsonObject.starts_with(char):
                raise unexpected_symbol_error(char, stream.char_position)

            is_comma_expected = False
            while True:
                await skip_whitespaces(stream)
                # Peek so that the string reader still sees its first char.
                char = await stream.apeek()
                if char == "}":
                    await stream.askip()
                    break

                if char == ",":
                    if not is_comma_expected:
                        raise unexpected_symbol_error(
                            char, stream.char_position
                        )

                    await stream.askip()
                    is_comma_expected = False
                elif JsonString.starts_with(char):
                    if is_comma_expected:
                        raise unexpected_symbol_error(
                            char, stream.char_position
                        )

                    key = await join_string(JsonString.read(stream))
                    await skip_whitespaces(stream)
                    colon = await anext(stream)
                    if colon != ":":
                        raise unexpected_symbol_error(
                            colon, stream.char_position
                        )

                    value = await node_parser.parse(stream)
                    yield key, value

                    # Once resumed, drain a compound value before going on.
                    # NOTE(review): presumably needed because the nested node
                    # reads from the same stream - confirm.
                    if isinstance(value, CompoundNode):
                        await value.read_to_end()

                    is_comma_expected = True
                else:
                    raise unexpected_symbol_error(char, stream.char_position)
        except StopAsyncIteration:
            # The input ended before the closing "}" was consumed.
            raise unexpected_end_of_stream_error(stream.char_position)

    @override
    async def to_chunks(self) -> AsyncIterator[str]:
        """Yield the object's JSON text piece by piece.

        Emits ``{``, then for each iterated entry the ``json.dumps``-encoded
        key, ``": "`` and the chunks of the value node, with ``", "`` between
        entries, and finally ``}``.
        """
        yield "{"
        is_first_entry = True
        async for key, value in self:
            if not is_first_entry:
                yield ", "

            yield json.dumps(key)
            yield ": "
            async for chunk in value.to_chunks():
                yield chunk

            is_first_entry = False

        yield "}"

    @override
    def value(self) -> dict[str, Any]:
        """Return a dict of the accumulated entries mapped to their values."""
        return {k: v.value() for k, v in self._object.items()}

    @override
    def _accumulate(self, element: Tuple[str, JsonNode]):
        """Store a ``(key, node)`` pair; a repeated key replaces the old node."""
        self._object[element[0]] = element[1]

    @classmethod
    def parse(
        cls, stream: ChunkedCharStream, node_parser: NodeParser
    ) -> "JsonObject":
        """Build a node backed by ``read`` at the stream's current position."""
        return cls(JsonObject.read(stream, node_parser), stream.char_position)

    @staticmethod
    def starts_with(char: str) -> bool:
        """Return whether ``char`` is the object opening brace ``{``."""
        return char == "{"