aidial_sdk/utils/merge_chunks.py (143 lines of code) (raw):
import copy
from typing import Any, List, TypeVar, Union, cast
T = TypeVar("T")
Path = List[Union[int, str]]
LIST_OF_DICTS_ERROR_MESSAGE = (
"Lists could be merged only if their elements are dictionaries"
)
INDEX_ERROR_MESSAGE = "A list element must have 'index' field to identify position of the element in the list"
INCONSISTENT_INDEXED_LIST_ERROR_MESSAGE = (
"All elements of a list must be either indexed or not indexed"
)
CANNOT_MERGE_NON_INDEXED_LISTS_ERROR_MESSAGE = (
"Cannot merge two non-indexed non-empty lists"
)
CANNOT_MERGE_NON_INDEXED_AND_INDEXED_LISTS_ERROR_MESSAGE = (
"Cannot merge a non-indexed list with an indexed list"
)
def show_json_path(path: Path) -> str:
ret = "$"
for elem in path:
if isinstance(elem, int):
ret += f"[{elem}]"
else:
ret += f".{elem}"
return ret
def merge_str(target: str, source: str, path: Path) -> str:
return target + source
def merge_int(target: int, source: int, path: Path) -> int:
return source
def merge_float(target: float, source: float, path: Path) -> float:
return source
def merge_bool(target: bool, source: bool, path: Path) -> bool:
return source
def merge_dicts(target: dict, source: dict, path: Path) -> dict:
for key, value in source.items():
path.append(key)
target[key] = merge_recursive(target.get(key), value, path)
path.pop()
return target
def is_indexed_list(xs: list) -> bool:
if len(xs) == 0:
return False
all_indexed = True
any_indexed = False
for elem in xs:
if isinstance(elem, dict) and "index" in elem:
any_indexed = True
else:
all_indexed = False
if any_indexed and not all_indexed:
raise AssertionError(INCONSISTENT_INDEXED_LIST_ERROR_MESSAGE)
return all_indexed
def merge_indexed_lists(target: list, source: list, path: Path) -> list:
for elem in source:
assert isinstance(elem, dict), LIST_OF_DICTS_ERROR_MESSAGE
index = elem.get("index")
assert isinstance(index, int), INDEX_ERROR_MESSAGE
path.append(index)
if index < len(target):
target[index] = merge_recursive(target[index], elem, path)
else:
target.extend([{"index": idx} for idx in range(len(target), index)])
target.append(copy.deepcopy(elem))
path.pop()
return target
def merge_lists(target: list, source: list, path: Path) -> list:
is_target_indexed = is_indexed_list(target)
is_source_indexed = is_indexed_list(source)
if len(source) == 0:
return target
if len(target) == 0:
if is_source_indexed:
return merge_indexed_lists(target, source, path)
else:
return copy.deepcopy(source)
if not is_target_indexed and not is_source_indexed:
raise AssertionError(CANNOT_MERGE_NON_INDEXED_LISTS_ERROR_MESSAGE)
assert (
is_target_indexed and is_source_indexed
), CANNOT_MERGE_NON_INDEXED_AND_INDEXED_LISTS_ERROR_MESSAGE
return merge_indexed_lists(target, source, path)
def merge_recursive(target: T, source: Any, path: Path) -> T:
"""
Recursively merging content of the source object into the target object.
The target object is modified in-place.
The source object is left unmodified.
"""
if source is None:
return target
if target is None:
if isinstance(source, dict):
target = cast(T, {})
elif isinstance(source, list):
target = cast(T, [])
else:
return source
if isinstance(target, list) and isinstance(source, list):
return merge_lists(target, source, path)
elif isinstance(target, dict) and isinstance(source, dict):
return merge_dicts(target, source, path)
elif isinstance(target, int) and isinstance(source, int):
return merge_int(target, source, path)
elif isinstance(target, float) and isinstance(source, float):
return merge_float(target, source, path)
elif isinstance(target, bool) and isinstance(source, bool):
return merge_bool(target, source, path)
elif isinstance(target, str) and isinstance(source, str):
return merge_str(target, source, path)
raise TypeError(
f"Cannot merge '{type(target).__name__}' with incoming '{type(source).__name__}' at path {show_json_path(path)}"
)
def merge(*chunks: T) -> T:
"""
Merge a list of chunks into one.
The very first chunk is modified in-place by accumulating the content of the subsequent chunks.
The subsequent chunks aren't modified.
The new content added to the first chunk is deeply copied from a source chunk.
"""
assert len(chunks) > 0, "At least one chunk must be provided"
ret: T = chunks[0]
for chunk in chunks[1:]:
ret = merge_recursive(ret, chunk, path=[])
return ret
def cleanup_indices(chunk: T) -> T:
"""
Recursively remove all "index" fields inside list elements in the given chunk.
The chunk is modified in-place.
"""
if isinstance(chunk, list):
ret = []
for elem in chunk:
if isinstance(elem, dict) and "index" in elem:
elem = elem.copy()
del elem["index"]
ret.append(cleanup_indices(elem))
return cast(T, ret)
if isinstance(chunk, dict):
return cast(
T, {key: cleanup_indices(value) for key, value in chunk.items()}
)
return chunk
_Chunk = TypeVar("_Chunk", bound=dict)
def merge_chat_completion_chunks(*chunks: _Chunk) -> _Chunk:
"""
The recursive merging procedure that avoids merging top-level atomic fields
(e.g. "id", "created", "model", "object", "system_fingerprint") and
instead chooses an _override_ merging strategy for such fields.
Non-atomic fields (e.g. "choice", "usage") are merged following
the standard recursive merging procedure.
The very first chunk is modified in-place.
The subsequent chunks are left unmodified.
"""
assert (
len(chunks) > 0
), "At least one chat completion chunk must be provided"
assert all(
isinstance(chunk, dict) for chunk in chunks
), "The chat completion chunks are expected to be dictionaries"
target, *sources = chunks
for chunk in sources:
source = cast(_Chunk, chunk.copy())
for key, value in list(source.items()):
if not isinstance(value, (list, dict)) and value is not None:
target[key] = value
del source[key]
target = merge(target, source)
return target