aidial_interceptors_sdk/chat_completion/helpers.py (122 lines of code) (raw):

from typing import Awaitable, Callable, List, TypeVar, overload

from aidial_interceptors_sdk.utils.not_given import NOT_GIVEN, NotGiven

T = TypeVar("T")
P = TypeVar("P")


@overload
async def traverse_dict_value(
    path: P,
    d: dict,
    key: str,
    on_value: Callable[
        [P, T | NotGiven | None],
        Awaitable[T | NotGiven | None],
    ],
) -> dict: ...


@overload
async def traverse_dict_value(
    path: P,
    d: NotGiven,
    key: str,
    on_value: Callable[
        [P, T | NotGiven | None],
        Awaitable[T | NotGiven | None],
    ],
) -> NotGiven: ...


@overload
async def traverse_dict_value(
    path: P,
    d: None,
    key: str,
    on_value: Callable[
        [P, T | NotGiven | None],
        Awaitable[T | NotGiven | None],
    ],
) -> None: ...


async def traverse_dict_value(
    path: P,
    d: dict | NotGiven | None,
    key: str,
    on_value: Callable[
        [P, T | NotGiven | None],
        Awaitable[T | NotGiven | None],
    ],
) -> dict | NotGiven | None:
    """Transform the optional entry ``d[key]`` with ``on_value``.

    If ``d`` is ``None`` or a ``NotGiven`` instance it is returned as-is and
    ``on_value`` is not called. Otherwise ``on_value(path, value)`` is
    awaited, where ``value`` is ``d[key]``, or ``NOT_GIVEN`` when ``key`` is
    absent. ``d`` itself is never mutated:

    * a ``NotGiven`` result means "absent": ``d`` is returned unchanged if
      ``key`` was not in it, otherwise a shallow copy without ``key``;
    * any other result (including ``None``) is stored under ``key`` in a
      shallow copy of ``d``.

    Args:
        path: Opaque location marker passed through to ``on_value``.
        d: The dictionary to transform, or ``None`` / ``NotGiven``.
        key: The key whose value is transformed.
        on_value: Async callback mapping the old value to the new one.

    Returns:
        ``d`` unchanged, or a new dictionary reflecting the callback result.
    """
    if d is None or isinstance(d, NotGiven):
        return d

    old_value = d.get(key, NOT_GIVEN)
    new_value = await on_value(path, old_value)

    # Test by type rather than identity with the NOT_GIVEN singleton, matching
    # the guard above, so no NotGiven instance can leak into the output dict.
    if isinstance(new_value, NotGiven):
        if key not in d:
            return d
        return {k: v for k, v in d.items() if k != key}

    return {**d, key: new_value}


@overload
async def traverse_required_dict_value(
    path: P,
    d: None,
    key: str,
    on_value: Callable[[P, T], Awaitable[T]],
) -> None: ...


@overload
async def traverse_required_dict_value(
    path: P,
    d: NotGiven,
    key: str,
    on_value: Callable[[P, T], Awaitable[T]],
) -> NotGiven: ...


@overload
async def traverse_required_dict_value(
    path: P,
    d: dict,
    key: str,
    on_value: Callable[[P, T], Awaitable[T]],
) -> dict: ...
async def traverse_required_dict_value(
    path: P,
    d: dict | NotGiven | None,
    key: str,
    on_value: Callable[[P, T], Awaitable[T]],
) -> dict | NotGiven | None:
    """Transform the required entry ``d[key]`` with ``on_value``.

    If ``d`` is ``None`` or a ``NotGiven`` instance it is returned as-is and
    ``on_value`` is not called. Otherwise ``on_value(path, d[key])`` is
    awaited and a shallow copy of ``d`` with ``key`` set to the result is
    returned; ``d`` itself is not mutated.

    Args:
        path: Opaque location marker passed through to ``on_value``.
        d: The dictionary to transform, or ``None`` / ``NotGiven``.
        key: The required key whose value is transformed.
        on_value: Async callback mapping the old value to the new one.

    Returns:
        ``d`` unchanged, or a new dictionary with the transformed value.

    Raises:
        ValueError: If ``key`` is absent from ``d`` or its value is ``None``.
    """
    if d is None or isinstance(d, NotGiven):
        return d

    if key not in d:
        raise ValueError(f"Missing required key {key!r} in a dictionary")

    old_value = d[key]
    if old_value is None:
        # A present-but-null value is still rejected, but the message no
        # longer misreports the key as missing.
        raise ValueError(f"Required key {key!r} in a dictionary is None")

    new_value = await on_value(path, old_value)
    return {**d, key: new_value}


@overload
async def traverse_list(
    create_elem_path: Callable[[int], P],
    lst: NotGiven,
    on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> NotGiven: ...


@overload
async def traverse_list(
    create_elem_path: Callable[[int], P],
    lst: None,
    on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> None: ...


@overload
async def traverse_list(
    create_elem_path: Callable[[int], P],
    lst: List[T],
    on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T]: ...


async def traverse_list(
    create_elem_path: Callable[[int], P],
    lst: List[T] | NotGiven | None,
    on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T] | NotGiven | None:
    """Map ``on_elem`` over ``lst``, building a new list.

    If ``lst`` is ``None`` or a ``NotGiven`` instance it is returned as-is.
    Otherwise each element is passed to ``on_elem`` together with the path
    produced by ``create_elem_path(index)``. If the callback returns a list,
    its items are spliced into the result (so an element can be dropped by
    returning ``[]`` or expanded into several). Any other value is appended
    as a single element. Note that this means a ``T`` that is itself a
    ``list`` is always flattened. ``lst`` is not mutated.

    Args:
        create_elem_path: Builds the path for the element at a given index.
        lst: The list to transform, or ``None`` / ``NotGiven``.
        on_elem: Async callback transforming one element.

    Returns:
        ``lst`` unchanged, or the newly built list.
    """
    if lst is None or isinstance(lst, NotGiven):
        return lst

    ret: List[T] = []
    for position, elem in enumerate(lst):
        # A dict element's own "index" field, when present, takes precedence
        # over its list position (presumably the index carried by streamed
        # choices / tool calls — confirm against callers).
        idx = elem.get("index", position) if isinstance(elem, dict) else position
        result = await on_elem(create_elem_path(idx), elem)
        if isinstance(result, list):
            ret.extend(result)
        else:
            ret.append(result)
    return ret