# aidial_assistant/chain/command_chain.py
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Tuple, cast

from openai import BadRequestError

from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE
from aidial_assistant.chain.callbacks.args_callback import ArgsCallback
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
from aidial_assistant.chain.command_result import (
CommandInvocation,
CommandResult,
Status,
commands_to_text,
responses_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue, DialogueTurn
from aidial_assistant.chain.history import History
from aidial_assistant.chain.model_response_reader import (
AssistantProtocolException,
CommandsReader,
skip_to_json_start,
)
from aidial_assistant.commands.base import (
Command,
CommandConstructor,
FinalCommand,
)
from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream
from aidial_assistant.json_stream.exceptions import JsonParsingException
from aidial_assistant.json_stream.json_object import JsonObject
from aidial_assistant.json_stream.json_parser import JsonParser, string_node
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.model.model_client import (
ChatCompletionMessageParam,
ModelClient,
)
from aidial_assistant.utils.stream import CumulativeStream

logger = logging.getLogger(__name__)

DEFAULT_MAX_RETRY_COUNT = 3

# Some relatively large number to avoid a CxSAST warning about a potential DoS attack.
# Later, the upper limit will be provided by the DIAL Core (proxy).
MAX_MODEL_COMPLETION_CHUNKS = 32000
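
# Maps a command name to a constructor producing the corresponding Command.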
CommandDict = dict[str, CommandConstructor]


class LimitExceededException(Exception):
    """Raised when a request would exceed a limit imposed by the model request limiter."""


class ModelRequestLimiter(ABC):
    """Guards model requests against exceeding an externally imposed limit."""

    @abstractmethod
    async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
        """Raise LimitExceededException if sending these messages would exceed the limit."""


class CommandChain:
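    """Drives the assistant protocol loop.

    The chain sends the chat history to the model, parses the JSON commands in
    the reply, executes them, and feeds the results back to the model until a
    final answer is produced.

    A rough usage sketch (EchoCommand and the client/callback instances are
    illustrative placeholders, not part of this module):

        chain = CommandChain(
            name="assistant",
            model_client=model_client,
            command_dict={"echo": EchoCommand},
        )
        await chain.run_chat(history, callback)
    """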
def __init__(
self,
name: str,
model_client: ModelClient,
command_dict: CommandDict,
max_completion_tokens: int | None = None,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
):
self.name = name
self.model_client = model_client
self.command_dict = command_dict
self.model_extra_args = (
{}
if max_completion_tokens is None
else {"max_tokens": max_completion_tokens}
)
self.max_retry_count = max_retry_count

    def _log_message(self, role: str, content: str | None):
logger.debug(f"[{self.name}] {role}: {content or ''}")

    def _log_messages(self, messages: list[ChatCompletionMessageParam]):
if logger.isEnabledFor(logging.DEBUG):
for message in messages:
self._log_message(message["role"], message.get("content"))

    async def run_chat(
self,
history: History,
callback: ChainCallback,
model_request_limiter: ModelRequestLimiter | None = None,
):
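        """Run the assistant dialogue loop until the model produces a final reply.

        On protocol failures, falls back to generating a best-effort answer.
        On request failures, drops the last dialogue turn and does the same,
        re-raising instead when the dialogue is empty or the error is a rate
        limit (code 429).
        """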
dialogue = Dialogue()
try:
messages = history.to_protocol_messages()
while True:
dialogue_turn = await self._run_with_protocol_failure_retries(
callback,
messages + dialogue.messages,
model_request_limiter,
)
if dialogue_turn is None:
break
dialogue.append(dialogue_turn)
except (JsonParsingException, AssistantProtocolException):
messages = (
history.to_best_effort_messages(
"The next constructed API request is incorrect.",
dialogue,
)
if not dialogue.is_empty()
else history.to_user_messages()
)
await self._generate_result(messages, callback)
except (BadRequestError, LimitExceededException) as e:
if dialogue.is_empty() or (
isinstance(e, BadRequestError) and e.code == "429"
):
raise
# Assuming the context length is exceeded
dialogue.pop()
# TODO: Limit the error message size. The error message should not exceed reserved assistant overheads.
await self._generate_result(
history.to_best_effort_messages(str(e), dialogue), callback
)

    async def _run_with_protocol_failure_retries(
self,
callback: ChainCallback,
messages: list[ChatCompletionMessageParam],
model_request_limiter: ModelRequestLimiter | None = None,
) -> DialogueTurn | None:
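        """Request the next dialogue turn from the model.

        Retries up to max_retry_count times when the response violates the
        assistant protocol, feeding the parse error back to the model.
        Returns None once the model answers with a final reply instead of
        commands.
        """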
last_error: Exception | None = None
try:
self._log_messages(messages)
retries = Dialogue()
while True:
all_messages = self._reinforce_json_format(
messages + retries.messages
)
if model_request_limiter:
await model_request_limiter.verify_limit(all_messages)
chunk_stream = CumulativeStream(
self.model_client.agenerate(
all_messages, **self.model_extra_args # type: ignore
)
)
try:
commands, responses = await self._run_commands(
chunk_stream, callback
)
if responses:
request_text = commands_to_text(commands)
response_text = responses_to_text(responses)
callback.on_state(request_text, response_text)
return DialogueTurn(
assistant_message=request_text,
user_message=response_text,
)
break
except (JsonParsingException, AssistantProtocolException) as e:
logger.exception("Failed to process model response")
retry_count = retries.dialogue_turn_count()
                    callback.on_error(
                        "Error"
                        if retry_count == 0
                        else f"Error (retry {retry_count})",
                        "The model failed to construct an addon request.",
                    )
if retry_count >= self.max_retry_count:
raise
last_error = e
retries.append(
DialogueTurn(
assistant_message=chunk_stream.buffer,
user_message="Failed to parse JSON commands: "
+ str(e),
)
)
finally:
self._log_message("assistant", chunk_stream.buffer)
except (BadRequestError, LimitExceededException) as e:
if last_error:
# Retries can increase the prompt size, which may lead to token overflow.
# Thus, if the original error was a protocol error, it should be thrown instead.
raise last_error
callback.on_error("Error", str(e))
raise

    async def _run_commands(
self, chunk_stream: AsyncIterator[str], callback: ChainCallback
) -> Tuple[list[CommandInvocation], list[CommandResult]]:
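        """Parse the streamed model response and execute the requested commands.

        A FinalCommand streams its message straight to the result callback and
        ends the turn; all other commands are executed and their results are
        collected for the next turn.
        """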
char_stream = ChunkedCharStream(chunk_stream)
await skip_to_json_start(char_stream)
root_node = await JsonParser().parse(char_stream)
commands: list[CommandInvocation] = []
responses: list[CommandResult] = []
request_reader = CommandsReader(root_node)
async for invocation in request_reader.parse_invocations():
command_name = await invocation.parse_name()
command = self._create_command(command_name)
args = await invocation.parse_args()
if isinstance(command, FinalCommand):
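                # Skip the final reply when earlier commands already produced
                # responses: those results go back to the model first.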
if len(responses) > 0:
continue
message = string_node(await args.get("message"))
await CommandChain._to_result(
message
if isinstance(message, JsonString)
else message.to_chunks(),
callback.result_callback(),
)
break
else:
response = await CommandChain._execute_command(
command_name, command, args, callback
)
commands.append(
cast(CommandInvocation, invocation.node.value())
)
responses.append(response)
return commands, responses

    def _create_command(self, name: str) -> Command:
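        """Instantiate the command registered under the given name.

        Raises AssistantProtocolException if the name is unknown.
        """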
if name not in self.command_dict:
raise AssistantProtocolException(
f"The command '{name}' is expected to be one of {list(self.command_dict.keys())}"
)
return self.command_dict[name]()

    async def _generate_result(
self,
messages: list[ChatCompletionMessageParam],
callback: ChainCallback,
):
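        """Generate a plain reply from the model and stream it to the result callback."""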
stream = self.model_client.agenerate(messages)
await CommandChain._to_result(stream, callback.result_callback())

    @staticmethod
def _reinforce_json_format(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
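        """Return the messages with the last one wrapped in a template that enforces the JSON response format."""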
last_message = messages[-1].copy()
last_message["content"] = ENFORCE_JSON_FORMAT_TEMPLATE.render(
response=last_message.get("content", "")
)
return messages[:-1] + [last_message]

    @staticmethod
async def _to_args(
args: JsonObject, args_callback: ArgsCallback
) -> dict[str, Any]:
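        """Accumulate the streamed argument chunks into a parsed JSON object, echoing each chunk to the callback."""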
args_callback.on_args_start()
result = ""
async for chunk in args.to_chunks():
args_callback.on_args_chunk(chunk)
result += chunk
parsed_args = json.loads(result)
args_callback.on_args_end()
return parsed_args

    @staticmethod
async def _to_result(stream: AsyncIterator[str], callback: ResultCallback):
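        """Stream completion chunks to the result callback, reading at most MAX_MODEL_COMPLETION_CHUNKS chunks."""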
try:
for _ in range(MAX_MODEL_COMPLETION_CHUNKS):
chunk = await anext(stream)
callback.on_result(chunk)
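            # Reached only when the loop consumed the maximum number of chunks
            # without the stream ending.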
logger.warning(
f"Max chunk count of {MAX_MODEL_COMPLETION_CHUNKS} exceeded in the reply"
)
except StopAsyncIteration:
pass

    @staticmethod
async def _execute_command(
name: str,
command: Command,
args: JsonObject,
chain_callback: ChainCallback,
) -> CommandResult:
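        """Execute a single command, reporting progress via the command callback.

        Any exception is captured as an ERROR result instead of propagating.
        """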
try:
with chain_callback.command_callback() as command_callback:
command_callback.on_command(name)
response = await command.execute(
await CommandChain._to_args(
args, command_callback.args_callback()
),
command_callback.execution_callback(),
)
command_callback.on_result(response)
return {"status": Status.SUCCESS, "response": response.text}
except Exception as e:
logger.exception(f"Failed to execute command {name}")
return {"status": Status.ERROR, "response": str(e)}