from typing import Any

from langchain_community.tools.openapi.utils.api_models import (
    APIOperation,
    APIPropertyBase,
)
from openai.types.chat import ChatCompletionToolParam
from typing_extensions import override

from aidial_assistant.commands.base import (
    Command,
    ExecutionCallback,
    ResultObject,
    TextResult,
    get_required_field,
)
from aidial_assistant.commands.open_api import OpenAPIChatCommand
from aidial_assistant.commands.plugin_callback import PluginChainCallback
from aidial_assistant.commands.run_plugin import PluginInfo
from aidial_assistant.model.model_client import (
    ModelClient,
    ReasonLengthException,
)
from aidial_assistant.open_api.operation_selector import collect_operations
from aidial_assistant.tools_chain.tools_chain import (
    CommandTool,
    CommandToolDict,
    ToolsChain,
)
from aidial_assistant.utils.open_ai import (
    construct_tool,
    system_message,
    user_message,
)


def _construct_property(p: APIPropertyBase) -> dict[str, Any]:
    """Return a dict holding the ``type`` and ``description`` of *p*.

    A field whose value is ``None`` is omitted from the result, so the
    returned dict may contain both keys, one of them, or be empty.
    """
    schema: dict[str, Any] = {}
    for key, value in (("type", p.type), ("description", p.description)):
        if value is not None:
            schema[key] = value
    return schema


def _construct_tool(op: APIOperation) -> ChatCompletionToolParam:
    """Describe an API operation as a chat-completion tool.

    The operation's own properties and, when a request body is present, the
    request body's properties are merged into a single ``name -> description``
    mapping. If both declare the same name, the request-body entry replaces
    the earlier one.

    Args:
        op: The operation to convert.

    Returns:
        The result of ``construct_tool`` called with the operation id, the
        operation description (an empty string when it is missing), the
        merged properties and the names of the required properties.
    """
    params = list(op.properties)
    if op.request_body is not None:
        params.extend(op.request_body.properties)

    properties: dict[str, Any] = {}
    # A dict is used as an insertion-ordered set: a name declared both as an
    # operation property and as a request-body property must be listed only
    # once, because a "required" array with duplicate entries is not a valid
    # JSON Schema.
    required: dict[str, None] = {}
    for p in params:
        properties[p.name] = _construct_property(p)
        if p.required:
            required[p.name] = None

    return construct_tool(
        op.operation_id, op.description or "", properties, list(required)
    )


class RunTool(Command):
    """Command that runs a query against a plugin through a tools chain.

    Every operation collected from the plugin's OpenAPI specification is
    offered to the model as a tool; the text gathered by the chain callback
    is returned as the command result.
    """

    def __init__(
        self, model: ModelClient, plugin: PluginInfo, max_completion_tokens: int
    ):
        """Store the model client, the plugin and the completion token limit."""
        self.max_completion_tokens = max_completion_tokens
        self.plugin = plugin
        self.model = model

    @staticmethod
    def token():
        """Return the identifier of this command."""
        return "run-tool"

    @override
    async def execute(
        self, args: dict[str, Any], execution_callback: ExecutionCallback
    ) -> ResultObject:
        """Run the ``query`` argument through a tools chain for the plugin.

        Args:
            args: Command arguments; the ``query`` field is required.
            execution_callback: Callback wrapped by the chain callback.

        Returns:
            A ``TextResult`` holding the result accumulated by the chain
            callback.
        """
        query = get_required_field(args, "query")

        plugin_info = self.plugin.info
        operations = collect_operations(
            plugin_info.open_api, plugin_info.ai_plugin.api.url
        )

        def as_command_tool(operation: APIOperation) -> CommandTool:
            # The command is created lazily by the factory, while the tool
            # description is built right away.
            def make_command() -> OpenAPIChatCommand:
                return OpenAPIChatCommand(operation, self.plugin.auth)

            return (make_command, _construct_tool(operation))

        commands: CommandToolDict = {}
        for name, operation in operations.items():
            commands[name] = as_command_tool(operation)

        chain = ToolsChain(self.model, commands, self.max_completion_tokens)

        instructions = system_message(plugin_info.ai_plugin.description_for_model)
        request = user_message(query)
        chain_callback = PluginChainCallback(execution_callback)
        try:
            await chain.run_chat([instructions, request], chain_callback)
        except ReasonLengthException:
            # Deliberately ignored: whatever the callback has collected so
            # far is still returned below.
            pass

        return TextResult(chain_callback.result)
