aidial_interceptors_sdk/chat_completion/index_mapper.py (23 lines of code) (raw):

from typing import Dict, Generic, Hashable, Set, TypeVar

from aidial_sdk.pydantic_v1 import BaseModel

_Index = TypeVar("_Index", bound=Hashable)


class IndexMapper(BaseModel, Generic[_Index]):
    """
    Translate indices of an incoming stream into indices of an outgoing one.

    The outgoing stream may contain extra elements pinned to fixed
    positions. Those positions are claimed via `reserve`, and every
    incoming index is then assigned (once, on first sight) an outgoing
    index that does not collide with anything already handed out.
    Calling the mapper with the same incoming index always yields the
    same outgoing index.
    """

    # Incoming index -> outgoing index assigned to it by `__call__`.
    migrated: Dict[_Index, int] = {}
    # Every outgoing index handed out so far (reserved or auto-assigned).
    used_indices: Set[int] = set()
    # Position from which the search for the next free index starts.
    fresh_index: int = 0

    def reserve(self, index: int | None = None) -> int:
        """
        Claim an outgoing index.

        With an explicit `index`, that exact position is claimed;
        without one, the next free position is allocated.

        Returns:
            The claimed outgoing index.

        Raises:
            ValueError: if the requested `index` has already been handed out.
        """
        if index is not None:
            taken = self.used_indices
            if index in taken:
                raise ValueError(f"Index {index} is already taken")
            taken.add(index)
            return index

        return self._get_fresh_index()

    def __call__(self, index: _Index) -> int:
        """
        Return the outgoing index for the incoming `index`.

        The first call for a given `index` allocates a free outgoing
        index and records it; later calls return the recorded value.
        """
        if index in self.migrated:
            return self.migrated[index]

        allocated = self._get_fresh_index()
        self.migrated[index] = allocated
        return allocated

    def _get_fresh_index(self) -> int:
        """Allocate, mark as used, and return the next free outgoing index."""
        taken = self.used_indices

        # Advance the cursor past every position that is already occupied.
        # The cursor is left ON the returned index; the next call will step
        # over it because it is recorded as used below.
        while self.fresh_index in taken:
            self.fresh_index += 1

        allocated = self.fresh_index
        taken.add(allocated)
        return allocated