aidial_assistant/chain/history.py (140 lines of code) (raw):
from enum import Enum
from typing import Tuple, cast
from jinja2 import Template
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
)
from pydantic import BaseModel
from aidial_assistant.chain.command_result import (
CommandInvocation,
commands_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue
from aidial_assistant.commands.reply import Reply
from aidial_assistant.model.model_client import ModelClient
from aidial_assistant.utils.open_ai import assistant_message, system_message
class ContextLengthExceeded(Exception):
pass
class MessageScope(str, Enum):
INTERNAL = "internal" # internal dialog with plugins/addons, not visible to the user on the top level
USER = "user" # top-level dialog with the user
class ScopedMessage(BaseModel):
scope: MessageScope = MessageScope.USER
message: ChatCompletionMessageParam
user_index: int
class History:
def __init__(
self,
assistant_system_message_template: Template,
best_effort_template: Template,
scoped_messages: list[ScopedMessage],
):
self.assistant_system_message_template = (
assistant_system_message_template
)
self.best_effort_template = best_effort_template
self.scoped_messages = scoped_messages
def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
messages: list[ChatCompletionMessageParam] = []
scoped_message_iterator = iter(self.scoped_messages)
if self._is_first_system_message():
message = cast(
ChatCompletionSystemMessageParam,
next(scoped_message_iterator).message,
)
messages.append(
system_message(
self.assistant_system_message_template.render(
system_prefix=message["content"]
)
)
)
else:
messages.append(
system_message(self.assistant_system_message_template.render())
)
for scoped_message in scoped_message_iterator:
message = scoped_message.message
scope = scoped_message.scope
if scope == MessageScope.USER and message["role"] == "assistant":
# Clients see replies in plain text, but the model should understand how to reply appropriately.
content = commands_to_text(
[
CommandInvocation(
command=Reply.token(),
arguments={"message": message.get("content", "")},
)
]
)
messages.append(assistant_message(content))
else:
messages.append(message)
return messages
def to_user_messages(self) -> list[ChatCompletionMessageParam]:
return [
scoped_message.message
for scoped_message in self.scoped_messages
if scoped_message.scope == MessageScope.USER
]
def to_best_effort_messages(
self, error: str, dialogue: Dialogue
) -> list[ChatCompletionMessageParam]:
messages = self.to_user_messages()
last_message = messages[-1].copy()
last_message["content"] = self.best_effort_template.render(
message=last_message.get("content", ""),
error=error,
dialogue=dialogue.messages,
)
messages[-1] = last_message
return messages
async def truncate(
self, model_client: ModelClient, max_prompt_tokens: int
) -> Tuple["History", list[int]]:
discarded_messages = await self._get_discarded_messages(
model_client, max_prompt_tokens
)
if not discarded_messages:
return self, []
discarded_messages_set = set(discarded_messages)
return (
History(
assistant_system_message_template=self.assistant_system_message_template,
best_effort_template=self.best_effort_template,
scoped_messages=[
scoped_message
for index, scoped_message in enumerate(self.scoped_messages)
if index not in discarded_messages_set
],
),
discarded_messages,
)
async def _get_discarded_messages(
self, model_client: ModelClient, max_prompt_tokens: int
) -> list[int]:
discarded_protocol_messages = await model_client.get_discarded_messages(
self.to_protocol_messages(),
max_prompt_tokens,
)
if discarded_protocol_messages:
discarded_protocol_messages.sort()
discarded_messages = (
discarded_protocol_messages
if self._is_first_system_message()
else [index - 1 for index in discarded_protocol_messages]
)
user_indices = set(
self.scoped_messages[index].user_index
for index in discarded_messages
)
return [
index
for index, scoped_message in enumerate(self.scoped_messages)
if scoped_message.user_index in user_indices
]
return discarded_protocol_messages
def _is_first_system_message(self) -> bool:
return (
len(self.scoped_messages) > 0
and self.scoped_messages[0].message["role"] == "system"
)