aidial_adapter_vertexai/chat/truncate_prompt.py
from abc import ABC, abstractmethod
from typing import (
Awaitable,
Callable,
Generic,
List,
Optional,
Self,
Set,
Sized,
TypeVar,
)

from aidial_sdk.exceptions import ContextLengthExceededError
from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.exceptions import (
InvalidRequestError,
TruncatePromptSystemAndLastUserError,
)
from pydantic.v1 import BaseModel

DiscardedMessages = List[int]

_P = TypeVar("_P")


class TruncatedPrompt(BaseModel, Generic[_P]):
    prompt: _P
    discarded_messages: DiscardedMessages


class TruncatePromptError(ABC, BaseModel):
    @abstractmethod
    def to_dial_exception(self) -> DialException:
        pass

    def print(self) -> str:
        return self.to_dial_exception().message


class InconsistentLimitsError(TruncatePromptError):
    user_limit: int
    model_limit: int

    def to_dial_exception(self) -> DialException:
        return InvalidRequestError(
            f"The request maximum prompt tokens is {self.user_limit}. "
            f"However, the model's maximum context length is {self.model_limit} tokens."
        )


class ModelLimitOverflowError(TruncatePromptError):
    model_limit: int
    token_count: int

    def to_dial_exception(self) -> DialException:
        return ContextLengthExceededError(self.model_limit, self.token_count)


class UserLimitOverflowError(TruncatePromptError):
    user_limit: int
    token_count: int

    def to_dial_exception(self) -> DialException:
        return TruncatePromptSystemAndLastUserError(
            self.user_limit, self.token_count
        )


def _partition_indexer(chunks: List[int]) -> Callable[[int], List[int]]:
    """Returns a function that maps an index to indices of its partition.

    >>> [_partition_indexer([2, 3])(i) for i in range(5)]
    [[0, 1], [0, 1], [2, 3, 4], [2, 3, 4], [2, 3, 4]]
    """
mapping: dict[int, List[int]] = {}
offset = 0
for size in chunks:
chunk = list(range(offset, offset + size))
for idx in range(size):
mapping[offset + idx] = chunk
offset += size
return mapping.__getitem__


class TruncatablePrompt(ABC, Sized):
@abstractmethod
def is_required_message(self, index: int) -> bool:
"""
Returns True if the message at the given index is required,
meaning that the prompt truncation algorithm should not
remove such a message.
Typically it's all system messages and the last message in the conversation.
"""
...

    @abstractmethod
def partition_messages(self) -> List[int]:
"""
Returns a list of sizes of contiguous non-overlapping sequences of messages
which together represent partition of the list of messages.
Therefore, the sum of the sizes must be equal to the number of messages.
Each partition is either preserved or discarded by the prompt truncation algorithm *as a whole*.
Typical partition are:
* the trivial partition that turns each message into a single-element block of messages
* the turn-based partition that consolidates user-bot turns into two-element blocks of messages.
This is useful for models that require the conversation to be composed of
a whole number of user-bot turns followed by a user query.
"""
...

    @abstractmethod
def select(self, indices: Set[int]) -> Self:
"""
Return a new prompt composed of the messages with the given indices.
"""
...

    def omit(self, indices: Set[int]) -> Self:
return self.select(set(range(len(self))) - indices)

    async def truncate(
self,
*,
tokenizer: Callable[[Self], Awaitable[int]],
model_limit: Optional[int] = None,
user_limit: Optional[int] = None,
) -> TruncatedPrompt[Self]:
"""
Returns a list of indices of discarded messages and
the truncated prompt that doesn't include the discarded messages and fits into the given user limit.
The list of discarded messages is a prefix of the list of non-required messages and
its length is as minimal as possible to fit the truncated prompt into the given user limit.
Parameters:
* The tokenizer computes number of tokens in the given prompt.
* The model limit is the intrinsic context limit on the number of input tokes for the given model.
* The user limit (aka max_prompt_tokens) defines the number of tokens that the resulting truncated prompt must fit in.
Throws a DIAL exception when the truncation satisfying the given limits is impossible.
"""
result = await self.compute_discarded_messages(
tokenizer=tokenizer,
model_limit=model_limit,
user_limit=user_limit,
)
if isinstance(result, TruncatePromptError):
raise result.to_dial_exception()
return TruncatedPrompt(
discarded_messages=list(result),
prompt=self.omit(set(result)),
)

    async def compute_discarded_messages(
self,
*,
tokenizer: Callable[[Self], Awaitable[int]],
model_limit: Optional[int],
user_limit: Optional[int],
) -> DiscardedMessages | TruncatePromptError:
if (
user_limit is not None
and model_limit is not None
and user_limit > model_limit
):
return InconsistentLimitsError(
user_limit=user_limit, model_limit=model_limit
)
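
        # Without a user limit, only the model limit (if any) has to be
        # checked; no messages are ever discarded in this mode.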
if user_limit is None:
if model_limit is None:
return []
token_count = await tokenizer(self)
if token_count <= model_limit:
return []
return ModelLimitOverflowError(
model_limit=model_limit, token_count=token_count
)
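
        # The prompt already fits into the user limit as-is.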
if await tokenizer(self) <= user_limit:
return []
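
        # The partition blocks must cover every message exactly once.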
partition_sizes = self.partition_messages()
if sum(partition_sizes) != len(self):
raise ValueError(
"Partition sizes must add up to the number of messages."
)

        async def _tokenize_selected(indices: Set[int]) -> int:
return await tokenizer(self.select(indices))

        get_partition_indices = _partition_indexer(partition_sizes)
n = len(self)
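
        # Start from the partition blocks containing required messages
        # (typically the system messages and the last message); these are
        # never discarded.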
kept_indices: Set[int] = {
j
for i in range(n)
for j in get_partition_indices(i)
if self.is_required_message(i)
}
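
        # If even the required blocks alone exceed the user limit,
        # truncation cannot help.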
token_count = await _tokenize_selected(kept_indices)
if token_count > user_limit:
return UserLimitOverflowError(
user_limit=user_limit, token_count=token_count
)
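
        # Greedily add back whole partition blocks, newest first, while the
        # selected messages still fit into the user limit.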
for idx in reversed(range(n)):
if idx in kept_indices:
continue
chunk_indices = get_partition_indices(idx)
new_kept_indices = {*kept_indices, *chunk_indices}
if (
len(new_kept_indices) == n
or await _tokenize_selected(new_kept_indices) > user_limit
):
break
kept_indices = new_kept_indices
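
        # Everything that was not kept is reported as discarded, in ascending order.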
all_indices = set(range(n))
return sorted(list(all_indices - kept_indices))
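

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the adapter): a minimal
# TruncatablePrompt over a plain list of strings. The names _ListPrompt,
# _word_tokenizer and _demo are hypothetical and exist only to show how
# truncate() is meant to be called; a real implementation would wrap the
# chat messages of a concrete model and use its actual tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    class _ListPrompt(TruncatablePrompt):
        def __init__(self, messages: List[str]):
            self._messages = messages

        def __len__(self) -> int:
            return len(self._messages)

        def is_required_message(self, index: int) -> bool:
            # Hypothetical policy: the first (system) and the last messages are required.
            return index == 0 or index == len(self) - 1

        def partition_messages(self) -> List[int]:
            # Trivial partition: every message is its own block.
            return [1] * len(self)

        def select(self, indices: Set[int]) -> "_ListPrompt":
            return _ListPrompt(
                [m for i, m in enumerate(self._messages) if i in indices]
            )

    async def _word_tokenizer(prompt: "_ListPrompt") -> int:
        # Crude stand-in for a real tokenizer: one token per whitespace-separated word.
        return sum(len(m.split()) for m in prompt._messages)

    async def _demo() -> None:
        prompt = _ListPrompt(["system", "one two", "three four", "question"])
        result = await prompt.truncate(tokenizer=_word_tokenizer, user_limit=4)
        # The oldest non-required message (index 1) is discarded first.
        print(result.discarded_messages)  # [1]

    asyncio.run(_demo())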