# aidial_assistant/application/assistant_application.py
import logging
from pathlib import Path
from typing import Tuple
from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.response import Response
from openai.lib.azure import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel
from aidial_assistant.application.addons_dialogue_limiter import (
AddonsDialogueLimiter,
)
from aidial_assistant.application.args import parse_args
from aidial_assistant.application.assistant_callback import (
AssistantChainCallback,
)
from aidial_assistant.application.prompts import (
MAIN_BEST_EFFORT_TEMPLATE,
MAIN_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import (
CommandChain,
CommandConstructor,
CommandDict,
)
from aidial_assistant.chain.history import History, ScopedMessage
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.tools_chain.tools_chain import (
CommandToolDict,
ToolsChain,
convert_commands_to_tools,
)
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
)
from aidial_assistant.utils.open_ai import construct_tool
from aidial_assistant.utils.open_ai_plugin import (
AddonTokenSource,
get_open_ai_plugin_info,
get_plugin_auth,
)
from aidial_assistant.utils.state import State, parse_history
# Module-level logger, named after this module per the standard convention.
logger = logging.getLogger(__name__)
class AddonReference(BaseModel):
    """Validated reference to a requested addon: optional display name plus its URL."""
    # Name the user gave the addon in the request; may be absent.
    name: str | None
    # Addon URL; guaranteed non-None after _validate_addons.
    url: str
def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
"temperature": request.temperature,
"user": request.user,
}
return {k: v for k, v in args.items() if v is not None}
def _validate_addons(addons: list[Addon] | None) -> list[AddonReference]:
    """Turn the request's addons into references, requiring a URL on each.

    Raises:
        RequestParameterValidationError: when an addon has no url.
    """

    def to_reference(position: int, entry: Addon) -> AddonReference:
        # Every addon must carry a URL; the index pinpoints the offender.
        if entry.url is None:
            raise RequestParameterValidationError(
                f"Missing required addon url at index {position}.",
                param="addons",
            )
        return AddonReference(name=entry.name, url=entry.url)

    return [to_reference(i, a) for i, a in enumerate(addons or [])]
def _validate_messages(messages: list[Message]) -> None:
    """Check the conversation is non-empty and ends with a user message.

    Raises:
        RequestParameterValidationError: on an empty list or a non-user tail.
    """
    if not messages:
        raise RequestParameterValidationError(
            "Message list cannot be empty.", param="messages"
        )

    last_message = messages[-1]
    if last_message.role != Role.USER:
        raise RequestParameterValidationError(
            "Last message must be from the user.", param="messages"
        )
def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
    """Build a tool definition exposing a single required natural-language query."""
    # The only parameter every addon tool accepts is a free-form task string.
    query_properties = {
        "query": {
            "type": "string",
            "description": "A task written in natural language",
        }
    }
    required_fields = ["query"]
    return construct_tool(name, description, query_properties, required_fields)
def _create_history(
    messages: list[ScopedMessage], plugins: list[PluginInfo]
) -> History:
    """Assemble a History whose prompts embed the addon descriptions.

    For each plugin the OpenAPI description is preferred, falling back to
    the human-readable plugin description when the former is absent.
    """
    descriptions: dict[str, str] = {}
    for plugin in plugins:
        ai_plugin = plugin.info.ai_plugin
        descriptions[ai_plugin.name_for_model] = (
            plugin.info.open_api.info.description
            or ai_plugin.description_for_human
        )

    return History(
        assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
            addons=descriptions
        ),
        best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
            addons=descriptions
        ),
        scoped_messages=messages,
    )
class AssistantApplication(ChatCompletion):
    """DIAL chat-completion application that orchestrates addon-backed chats.

    Depending on whether the requested deployment supports native OpenAI
    tools, the conversation runs either through the native tools protocol
    (``_run_native_tools_chat``) or through an emulated command-based
    dialogue (``_run_emulated_tools_chat``).
    """

    def __init__(
        self, config_dir: Path, tools_supporting_deployments: set[str]
    ):
        # Application-wide configuration parsed from the config directory.
        self.args = parse_args(config_dir)
        # Deployments that support the native OpenAI tools API.
        self.tools_supporting_deployments = tools_supporting_deployments

    @unhandled_exception_handler
    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        """Validate the request, resolve addon plugins and run the chat."""
        _validate_messages(request.messages)
        addon_references = _validate_addons(request.addons)
        chat_args = _get_request_args(request)

        model = ModelClient(
            client=AsyncAzureOpenAI(
                azure_endpoint=self.args.openai_conf.api_base,
                api_key=request.api_key,
                # 2023-12-01-preview is needed to support tools
                api_version="2023-12-01-preview",
            ),
            model_args=chat_args,
        )

        token_source = AddonTokenSource(
            request.headers,
            (addon_reference.url for addon_reference in addon_references),
        )

        plugins: list[PluginInfo] = []
        # DIAL Core has own names for addons, so in stages we need to map
        # them to the names used by the user.
        addon_name_mapping: dict[str, str] = {}
        for addon_reference in addon_references:
            info = await get_open_ai_plugin_info(addon_reference.url)
            plugins.append(
                PluginInfo(
                    info=info,
                    auth=get_plugin_auth(
                        info.ai_plugin.auth.type,
                        info.ai_plugin.auth.authorization_type,
                        addon_reference.url,
                        token_source,
                    ),
                )
            )

            if addon_reference.name:
                addon_name_mapping[
                    info.ai_plugin.name_for_model
                ] = addon_reference.name

        if request.model in self.tools_supporting_deployments:
            await AssistantApplication._run_native_tools_chat(
                model, plugins, addon_name_mapping, request, response
            )
        else:
            await AssistantApplication._run_emulated_tools_chat(
                model, plugins, addon_name_mapping, request, response
            )

    @staticmethod
    async def _run_emulated_tools_chat(
        model: ModelClient,
        addons: list[PluginInfo],
        addon_name_mapping: dict[str, str],
        request: Request,
        response: Response,
    ):
        """Run the conversation through the emulated command-based protocol."""
        # TODO: Add max_addons_dialogue_tokens as a request parameter
        max_addons_dialogue_tokens = 1000

        def create_command(addon: PluginInfo) -> CommandConstructor:
            # Bind the addon via a closure so each command targets its plugin.
            return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens)

        command_dict: CommandDict = {
            addon.info.ai_plugin.name_for_model: create_command(addon)
            for addon in addons
        }
        if Reply.token() in command_dict:
            # BUGFIX: previously this error was constructed but never raised,
            # silently letting a conflicting addon shadow the reply command.
            raise RequestParameterValidationError(
                f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
                param="addons",
            )

        command_dict[Reply.token()] = Reply

        chain = CommandChain(
            model_client=model, name="ASSISTANT", command_dict=command_dict
        )
        scoped_messages = parse_history(request.messages)
        # Reuse the module-level helper instead of duplicating the prompt
        # construction inline (it builds identical templates).
        history = _create_history(scoped_messages, addons)
        discarded_user_messages: set[int] | None = None
        if request.max_prompt_tokens is not None:
            history, discarded_messages = await history.truncate(
                model, request.max_prompt_tokens
            )
            # Map discarded scoped-message indices back to the indices of the
            # user's original messages for reporting.
            discarded_user_messages = set(
                scoped_messages[index].user_index
                for index in discarded_messages
            )
        # TODO: else compare the history size to the max prompt tokens of the underlying model

        choice = response.create_single_choice()
        choice.open()

        callback = AssistantChainCallback(choice, addon_name_mapping)
        finish_reason = FinishReason.STOP
        try:
            model_request_limiter = AddonsDialogueLimiter(
                max_addons_dialogue_tokens, model
            )
            await chain.run_chat(history, callback, model_request_limiter)
        except ReasonLengthException:
            # The model ran out of tokens; report LENGTH instead of STOP.
            finish_reason = FinishReason.LENGTH

        if callback.invocations:
            choice.set_state(State(invocations=callback.invocations))
        choice.close(finish_reason)

        response.set_usage(
            model.total_prompt_tokens, model.total_completion_tokens
        )

        if discarded_user_messages is not None:
            response.set_discarded_messages(list(discarded_user_messages))

    @staticmethod
    async def _run_native_tools_chat(
        model: ModelClient,
        plugins: list[PluginInfo],
        addon_name_mapping: dict[str, str],
        request: Request,
        response: Response,
    ):
        """Run the conversation through the native OpenAI tools protocol."""
        # TODO: Add max_addons_dialogue_tokens as a request parameter
        max_addons_dialogue_tokens = 1000

        def create_command_tool(
            plugin: PluginInfo,
        ) -> Tuple[CommandConstructor, ChatCompletionToolParam]:
            # Pair the runnable command with its OpenAI tool declaration.
            return lambda: RunTool(
                model, plugin, max_addons_dialogue_tokens
            ), _construct_tool(
                plugin.info.ai_plugin.name_for_model,
                plugin.info.ai_plugin.description_for_human,
            )

        commands: CommandToolDict = {
            plugin.info.ai_plugin.name_for_model: create_command_tool(plugin)
            for plugin in plugins
        }
        chain = ToolsChain(model, commands)

        choice = response.create_single_choice()
        choice.open()

        callback = AssistantChainCallback(choice, addon_name_mapping)
        finish_reason = FinishReason.STOP
        messages = convert_commands_to_tools(parse_history(request.messages))
        try:
            model_request_limiter = AddonsDialogueLimiter(
                max_addons_dialogue_tokens, model
            )
            await chain.run_chat(messages, callback, model_request_limiter)
        except ReasonLengthException:
            # The model ran out of tokens; report LENGTH instead of STOP.
            finish_reason = FinishReason.LENGTH

        if callback.invocations:
            choice.set_state(State(invocations=callback.invocations))
        choice.close(finish_reason)

        response.set_usage(
            model.total_prompt_tokens, model.total_completion_tokens
        )