aidial_adapter_bedrock/llm/truncate_prompt.py (121 lines of code):

from abc import ABC, abstractmethod
from typing import Awaitable, Callable, List, Optional, Set, Tuple, TypeVar

from aidial_sdk.exceptions import ContextLengthExceededError
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import (
    InvalidRequestError,
    TruncatePromptSystemAndLastUserError,
)
from pydantic import BaseModel

from aidial_adapter_bedrock.utils.list import omit_by_indices, select_by_indices


class TruncatePromptError(ABC, BaseModel):
    @abstractmethod
    def to_dial_exception(self) -> DialException:
        pass

    def print(self) -> str:
        return self.to_dial_exception().message


class InconsistentLimitsError(TruncatePromptError):
    user_limit: int
    model_limit: int

    def to_dial_exception(self) -> DialException:
        return InvalidRequestError(
            f"The request maximum prompt tokens is {self.user_limit}. "
            f"However, the model's maximum context length is {self.model_limit} tokens."
        )


class ModelLimitOverflow(TruncatePromptError):
    model_limit: int
    token_count: int

    def to_dial_exception(self) -> DialException:
        return ContextLengthExceededError(self.model_limit, self.token_count)


class UserLimitOverflow(TruncatePromptError):
    user_limit: int
    token_count: int

    def to_dial_exception(self) -> DialException:
        return TruncatePromptSystemAndLastUserError(
            self.user_limit, self.token_count
        )


def _partition_indexer(chunks: List[int]) -> Callable[[int], List[int]]:
    """
    Returns a function that maps an index to the indices of its partition.
    """
    # Precompute, for every message index, the full list of indices in the
    # partition (chunk) that contains it.
    mapping: dict[int, List[int]] = {}
    offset = 0
    for size in chunks:
        chunk = list(range(offset, offset + size))
        for idx in range(size):
            mapping[offset + idx] = chunk
        offset += size
    return mapping.__getitem__


_T = TypeVar("_T")

DiscardedMessages = List[int]


async def truncate_prompt(
    messages: List[_T],
    tokenizer: Callable[[List[_T]], Awaitable[int]],
    keep_message: Callable[[List[_T], int], bool],
    partitioner: Callable[[List[_T]], List[int]],
    model_limit: Optional[int],
    user_limit: Optional[int],
) -> Tuple[DiscardedMessages, List[_T]]:
    """
    Returns a list of indices of discarded messages and a list of
    preserved messages.
    """
    result = await compute_discarded_messages(
        messages,
        tokenizer,
        keep_message,
        partitioner,
        model_limit,
        user_limit,
    )

    if isinstance(result, TruncatePromptError):
        raise result.to_dial_exception()

    return (list(result), omit_by_indices(messages, result))


async def compute_discarded_messages(
    messages: List[_T],
    tokenizer: Callable[[List[_T]], Awaitable[int]],
    keep_message: Callable[[List[_T], int], bool],
    partitioner: Callable[[List[_T]], List[int]],
    model_limit: Optional[int],
    user_limit: Optional[int],
) -> DiscardedMessages | TruncatePromptError:
    if (
        user_limit is not None
        and model_limit is not None
        and user_limit > model_limit
    ):
        return InconsistentLimitsError(
            user_limit=user_limit, model_limit=model_limit
        )

    if user_limit is None:
        # No user limit: nothing is discarded; only verify that the whole
        # prompt fits into the model limit (if there is one).
        if model_limit is None:
            return []

        token_count = await tokenizer(messages)
        if token_count <= model_limit:
            return []

        return ModelLimitOverflow(
            model_limit=model_limit, token_count=token_count
        )

    partition_sizes = partitioner(messages)
    if sum(partition_sizes) != len(messages):
        raise ValueError(
            "Partition sizes must add up to the number of messages."
        )

    async def _tokenize_selected(indices: Set[int]) -> int:
        return await tokenizer(select_by_indices(messages, indices))

    get_partition_indices = _partition_indexer(partition_sizes)

    n = len(messages)

    # Start from the messages that must be kept (e.g. the system prompt and
    # the last user message), expanded to whole partitions.
    kept_indices: Set[int] = {
        j
        for i in range(n)
        for j in get_partition_indices(i)
        if keep_message(messages, i)
    }

    token_count = await _tokenize_selected(kept_indices)
    if token_count > user_limit:
        return UserLimitOverflow(user_limit=user_limit, token_count=token_count)

    # Greedily add partitions from the end of the conversation (the most
    # recent messages) while the prompt still fits into the user limit.
    for idx in reversed(range(n)):
        if idx in kept_indices:
            continue

        chunk_indices = get_partition_indices(idx)
        new_token_count = await _tokenize_selected(
            {*kept_indices, *chunk_indices}
        )
        if new_token_count > user_limit:
            break

        kept_indices.update(chunk_indices)

    all_indices = set(range(n))
    return sorted(list(all_indices - kept_indices))
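
A minimal usage sketch, not part of the module: it drives truncate_prompt with toy callbacks to show how the pieces above fit together. The word_count_tokenizer, the keep_system_and_last predicate, and the singleton partitioner are hypothetical stand-ins for the adapter's real tokenizer and partitioning logic; only the import path comes from the file above, and the snippet assumes the package and its dependencies are installed.

import asyncio
from typing import List

from aidial_adapter_bedrock.llm.truncate_prompt import truncate_prompt


async def word_count_tokenizer(messages: List[str]) -> int:
    # Toy tokenizer: pretend each whitespace-separated word is one token.
    return sum(len(m.split()) for m in messages)


def keep_system_and_last(messages: List[str], idx: int) -> bool:
    # Always keep the first (system) and the last (current user) message.
    return idx == 0 or idx == len(messages) - 1


async def main() -> None:
    messages = [
        "You are a helpful assistant.",  # system, 5 "tokens"
        "What is the capital of France?",  # 6 "tokens"
        "Paris.",  # 1 "token"
        "And of Germany?",  # 3 "tokens"
    ]
    discarded, kept = await truncate_prompt(
        messages,
        tokenizer=word_count_tokenizer,
        keep_message=keep_system_and_last,
        partitioner=lambda msgs: [1] * len(msgs),  # every message on its own
        model_limit=None,
        user_limit=12,
    )
    # The mandatory messages cost 8 tokens; adding "Paris." (9) still fits
    # the limit of 12, but adding the 6-token question would not, so only
    # index 1 is discarded:
    print(discarded)  # [1]
    print(kept)  # system prompt, "Paris.", "And of Germany?"


asyncio.run(main())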