aidial_adapter_openai/utils/truncate_prompt.py (42 lines of code) (raw):

from typing import Callable, List, Set, Tuple, TypeVar

from aidial_sdk.exceptions import (
    TruncatePromptSystemAndLastUserError,
    TruncatePromptSystemError,
)

_T = TypeVar("_T")

# Indices (into the original message list) of the messages that were dropped.
DiscardedMessages = List[int]
# Token count returned together with the truncated prompt.
TruncatedTokens = int


def truncate_prompt(
    messages: List[_T],
    message_tokens: Callable[[_T], int],
    is_system_message: Callable[[_T], bool],
    max_prompt_tokens: int,
    initial_prompt_tokens: int,
) -> Tuple[List[_T], DiscardedMessages, TruncatedTokens]:
    """Drop the oldest non-system messages until the prompt fits the budget.

    Every system message is always kept. Non-system messages are then
    considered from the last one to the first one and kept while the running
    token total stays within ``max_prompt_tokens``. As soon as one message
    does not fit, it and every earlier non-system message are discarded
    (earlier messages are not tried even if they are smaller).

    Args:
        messages: The prompt messages in their original order.
        message_tokens: Returns the token count of a single message.
        is_system_message: Returns True for messages that must be kept.
        max_prompt_tokens: Upper bound for the resulting token total.
        initial_prompt_tokens: Starting value of the running token total,
            counted before any message is added.

    Returns:
        A tuple of:
          * the kept messages, in their original relative order;
          * the indices of the discarded messages, in ascending order;
          * the token total of the kept prompt, including
            ``initial_prompt_tokens``.

    Raises:
        TruncatePromptSystemError: If ``initial_prompt_tokens`` plus the
            system messages alone exceed ``max_prompt_tokens``.
        TruncatePromptSystemAndLastUserError: If the last non-system message
            does not fit on top of the system messages.
    """
    prompt_tokens = initial_prompt_tokens
    kept_messages: Set[int] = set()

    # System messages are unconditionally kept, so account for them first.
    for idx, message in enumerate(messages):
        if is_system_message(message):
            kept_messages.add(idx)
            prompt_tokens += message_tokens(message)

    if max_prompt_tokens < prompt_tokens:
        raise TruncatePromptSystemError(max_prompt_tokens, prompt_tokens)

    # Walk the non-system messages from newest to oldest so that the most
    # recent context is the part that survives truncation.
    kept_any_non_system = False
    for idx, message in reversed(list(enumerate(messages))):
        if is_system_message(message):
            continue

        candidate_tokens = message_tokens(message)
        if max_prompt_tokens < prompt_tokens + candidate_tokens:
            if not kept_any_non_system:
                # Not even the latest non-system message fits.
                raise TruncatePromptSystemAndLastUserError(
                    max_prompt_tokens, prompt_tokens + candidate_tokens
                )
            break

        prompt_tokens += candidate_tokens
        kept_messages.add(idx)
        kept_any_non_system = True

    new_messages = [
        message for idx, message in enumerate(messages) if idx in kept_messages
    ]
    # Scan indices in order instead of listing a set difference: set iteration
    # order is unspecified, which could yield the discarded indices unsorted.
    discarded_messages = [
        idx for idx in range(len(messages)) if idx not in kept_messages
    ]

    return new_messages, discarded_messages, prompt_tokens