aidial_assistant/commands/run_plugin.py (90 lines of code) (raw):
from langchain.tools import APIOperation
from pydantic.main import BaseModel
from typing_extensions import override
from aidial_assistant.application.prompts import (
ADDON_BEST_EFFORT_TEMPLATE,
ADDON_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import (
CommandChain,
CommandConstructor,
)
from aidial_assistant.chain.history import History, ScopedMessage
from aidial_assistant.commands.base import (
Command,
ExecutionCallback,
ResultObject,
TextResult,
get_required_field,
)
from aidial_assistant.commands.open_api import OpenAPIChatCommand
from aidial_assistant.commands.plugin_callback import PluginChainCallback
from aidial_assistant.commands.reply import Reply
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.open_api.operation_selector import collect_operations
from aidial_assistant.utils.open_ai import user_message
from aidial_assistant.utils.open_ai_plugin import OpenAIPluginInfo
class PluginInfo(BaseModel):
    """Addon metadata together with the auth value used for its API calls."""

    # Parsed addon description. RunPlugin reads its `open_api` spec and the
    # `ai_plugin` manifest fields (`api.url`, `name_for_model`,
    # `description_for_model`).
    info: OpenAIPluginInfo
    # Forwarded unchanged to every OpenAPIChatCommand built for this addon.
    # Presumably a credential/header value, with None meaning no auth —
    # NOTE(review): confirm against OpenAPIChatCommand.
    auth: str | None
class RunPlugin(Command):
    """Answer a query by running a nested command chain against one addon.

    Each OpenAPI operation of the addon is exposed to the model as a command,
    alongside the reserved ``Reply`` command. The text gathered by
    ``PluginChainCallback`` is returned as this command's result.
    """

    def __init__(
        self,
        model_client: ModelClient,
        plugin: PluginInfo,
        max_completion_tokens: int,
    ):
        """Store the dependencies used by each ``execute`` call.

        Args:
            model_client: Model client handed to the nested ``CommandChain``.
            plugin: Addon metadata and the auth value for its API commands.
            max_completion_tokens: Passed through to the nested
                ``CommandChain``.
        """
        self.model_client = model_client
        self.plugin = plugin
        self.max_completion_tokens = max_completion_tokens

    @staticmethod
    def token():
        """Return the name under which this command is invoked."""
        return "run-addon"

    @override
    async def execute(
        self, args: dict[str, str], execution_callback: ExecutionCallback
    ) -> ResultObject:
        """Run the addon for the ``query`` argument.

        ``query`` is read via ``get_required_field``. A missing field
        therefore surfaces whatever error that helper raises.
        """
        query = get_required_field(args, "query")
        return await self._run_plugin(query, execution_callback)

    async def _run_plugin(
        self, query: str, execution_callback: ExecutionCallback
    ) -> ResultObject:
        """Run a nested chat that answers ``query`` using the addon's API.

        Args:
            query: The request sent as the sole user message of the nested
                chat.
            execution_callback: Wrapped in a ``PluginChainCallback`` that
                collects the nested chain's output.

        Returns:
            A ``TextResult`` holding the callback's collected result.

        Raises:
            Exception: If the addon defines an operation whose name collides
                with the reserved ``Reply`` command name.
        """
        info = self.plugin.info
        ops = collect_operations(info.open_api, info.ai_plugin.api.url)
        api_schema = "\n\n".join([op.to_typescript() for op in ops.values()])  # type: ignore

        def create_command(op: APIOperation):
            # A factory function is needed so each lambda binds its own `op`.
            # A lambda written directly in a loop would late-bind the loop
            # variable, and every command would use the last operation.
            return lambda: OpenAPIChatCommand(op, self.plugin.auth)

        command_dict: dict[str, CommandConstructor] = {
            name: create_command(op) for name, op in ops.items()
        }

        # `Reply` is registered under its reserved name below. An addon
        # operation with the same name would be silently replaced, so it is
        # rejected explicitly. (The original code built this exception but
        # never raised it.)
        if Reply.token() in command_dict:
            raise Exception(
                f"Operation with name '{Reply.token()}' is not allowed."
            )
        command_dict[Reply.token()] = Reply

        history = History(
            assistant_system_message_template=ADDON_SYSTEM_DIALOG_MESSAGE.build(
                command_names=ops.keys(),
                api_description=info.ai_plugin.description_for_model,
                api_schema=api_schema,
            ),
            best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build(
                api_schema=api_schema
            ),
            scoped_messages=[
                ScopedMessage(message=user_message(query), user_index=0)
            ],
        )

        chat = CommandChain(
            model_client=self.model_client,
            name="PLUGIN:" + info.ai_plugin.name_for_model,
            command_dict=command_dict,
            max_completion_tokens=self.max_completion_tokens,
        )
        callback = PluginChainCallback(execution_callback)
        try:
            await chat.run_chat(history, callback)
        except ReasonLengthException:
            # Deliberate best effort: presumably raised when the model hits
            # its length limit. Whatever the callback collected so far is
            # still returned below.
            pass
        return TextResult(callback.result)