aidial_assistant/json_stream/json_string.py (81 lines of code) (raw):
import json
from collections.abc import AsyncIterator
from typing_extensions import override
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import (
JsonParsingException,
unexpected_end_of_stream_error,
unexpected_symbol_error,
)
from aidial_assistant.json_stream.json_node import CompoundNode
class JsonString(CompoundNode[str, str]):
    """JSON string node whose decoded text is produced incrementally.

    The node is built around an async iterator of decoded text pieces (see
    ``read``). Pieces handed to ``_accumulate`` are appended to an internal
    buffer that ``value`` returns.
    """

    # Escape designator (the character following the backslash) mapped to
    # the text it stands for. ``\u`` is handled separately in ``_escape``.
    _SIMPLE_ESCAPES = {
        '"': '"',
        "\\": "\\",
        "/": "/",
        "b": "\b",
        "f": "\f",
        "n": "\n",
        "r": "\r",
        "t": "\t",
    }

    # Characters permitted in the four-digit part of a ``\uXXXX`` escape.
    _HEX_DIGITS = frozenset("0123456789abcdefABCDEF")

    def __init__(self, source: AsyncIterator[str], pos: int):
        """Create the node.

        Args:
            source: Async iterator yielding decoded pieces of the string.
            pos: Position forwarded unchanged to the base class.
        """
        super().__init__(source, pos)
        self._buffer = ""

    @override
    def type(self) -> str:
        """Return the node type name, ``"string"``."""
        return "string"

    @override
    def _accumulate(self, element: str):
        """Append a decoded piece to the buffered value."""
        self._buffer += element

    @override
    async def to_chunks(self) -> AsyncIterator[str]:
        """Yield the string re-serialized as JSON text, piece by piece.

        Emits an opening quote, then every piece obtained by iterating this
        node, JSON-escaped, and finally a closing quote.
        """
        yield '"'
        async for chunk in self:
            # json.dumps wraps its output in quotes; strip them because the
            # surrounding quotes are emitted once for the whole string.
            yield json.dumps(chunk)[1:-1]
        yield '"'

    @override
    def value(self) -> str:
        """Return the text accumulated so far."""
        return self._buffer

    @classmethod
    def parse(cls, stream: ChunkedCharStream) -> "JsonString":
        """Build a node that reads a JSON string from ``stream``.

        ``read`` is an async generator, so nothing is consumed from the
        stream until the returned node is iterated.
        """
        return cls(JsonString.read(stream), stream.char_position)

    @staticmethod
    async def read(stream: ChunkedCharStream) -> AsyncIterator[str]:
        """Decode a quoted JSON string from ``stream`` and yield its text.

        The first character must be a double quote. Characters are then
        consumed up to the closing quote, with backslash escapes decoded by
        ``_escape``. Collected text is yielded whenever the stream's
        ``chunk_position`` changes, and once more at the end if any text
        remains.

        Raises:
            The error built by ``unexpected_symbol_error`` if the first
            character is not a double quote, the error built by
            ``unexpected_end_of_stream_error`` if the stream ends before the
            closing quote, and ``JsonParsingException`` for a malformed
            escape sequence.
        """
        try:
            char = await anext(stream)
            if not JsonString.starts_with(char):
                raise unexpected_symbol_error(char, stream.char_position)

            # Collect pieces in a list and join on flush rather than growing
            # a string one character at a time.
            parts: list[str] = []
            chunk_position = stream.chunk_position
            while True:
                char = await anext(stream)
                if char == '"':
                    break

                parts.append(
                    await JsonString._escape(stream) if char == "\\" else char
                )

                # Flush what has been collected once the stream reports a
                # different chunk position than the one last recorded.
                if chunk_position != stream.chunk_position:
                    yield "".join(parts)
                    parts = []
                    chunk_position = stream.chunk_position
        except StopAsyncIteration:
            # Also covers the stream ending in the middle of an escape.
            raise unexpected_end_of_stream_error(stream.char_position)

        if parts:
            yield "".join(parts)

    @staticmethod
    async def _escape(stream: ChunkedCharStream) -> str:
        """Decode the escape sequence that follows an already-read backslash.

        Returns:
            The text the escape sequence denotes.

        Raises:
            JsonParsingException: If the escape designator is unknown, or a
                ``\\u`` escape is not followed by exactly four hexadecimal
                digits.
            StopAsyncIteration: If the stream ends inside the sequence
                (``read`` converts this into an end-of-stream error).
        """
        char = await anext(stream)
        # Position reported for a malformed escape, captured before any
        # further characters are consumed.
        escape_position = stream.char_position - 1

        if char == "u":
            unicode_sequence = "".join([await anext(stream) for _ in range(4)])  # type: ignore
            # Validate explicitly: int(..., 16) would also accept a sign,
            # surrounding whitespace and underscores, and would raise a bare
            # ValueError on anything else.
            if len(unicode_sequence) != 4 or any(
                digit not in JsonString._HEX_DIGITS
                for digit in unicode_sequence
            ):
                raise JsonParsingException(
                    f"Invalid unicode escape sequence: \\u{unicode_sequence}.",
                    escape_position,
                )
            # NOTE: every \uXXXX escape is decoded on its own, so a UTF-16
            # surrogate pair written as two escapes results in two separate
            # surrogate code points rather than one combined character.
            return chr(int(unicode_sequence, 16))

        decoded = JsonString._SIMPLE_ESCAPES.get(char)
        if decoded is None:
            raise JsonParsingException(
                f"Unexpected escape sequence: \\{char}.",
                escape_position,
            )
        return decoded

    @staticmethod
    def starts_with(char: str) -> bool:
        """Return True if ``char`` is the double quote that opens a string."""
        return char == '"'