import asyncio
from typing import List, assert_never

from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
    TokenizeError,
    TokenizeInputRequest,
    TokenizeInputString,
    TokenizeOutput,
    TokenizeRequest,
    TokenizeResponse,
    TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
    TruncatePromptError,
    TruncatePromptRequest,
    TruncatePromptResponse,
    TruncatePromptResult,
    TruncatePromptSuccess,
)
from aidial_sdk.exceptions import ResourceNotFoundError
from google.genai.client import Client as GenAIClient
from typing_extensions import override

from aidial_adapter_vertexai.adapters import get_chat_completion_model
from aidial_adapter_vertexai.chat.chat_completion_adapter import (
    ChatCompletionAdapter,
    TruncatedPrompt,
)
from aidial_adapter_vertexai.chat.consumer import ChoiceConsumer
from aidial_adapter_vertexai.chat.errors import UserError, ValidationError
from aidial_adapter_vertexai.chat.static_tools import StaticToolsConfig
from aidial_adapter_vertexai.chat.tools import ToolsConfig
from aidial_adapter_vertexai.deployments import ChatCompletionDeployment
from aidial_adapter_vertexai.dial_api.exceptions import dial_exception_decorator
from aidial_adapter_vertexai.dial_api.request import ModelParameters
from aidial_adapter_vertexai.dial_api.token_usage import TokenUsage
from aidial_adapter_vertexai.utils.log_config import app_logger as log
from aidial_adapter_vertexai.utils.not_implemented import is_implemented


class VertexAIChatCompletion(ChatCompletion):
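    """DIAL chat completion endpoint backed by Vertex AI models.

    Dispatches the chat/completions, tokenize and truncate_prompt DIAL
    endpoints to the adapter resolved from the requested deployment.
    """
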
    def __init__(self, client: GenAIClient):
        self.client = client

    async def _get_model(
        self, request: FromRequestDeploymentMixin
    ) -> ChatCompletionAdapter:
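        """Resolve the chat adapter for the deployment named in the request."""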
        return await get_chat_completion_model(
            deployment=ChatCompletionDeployment(request.deployment_id),
            api_key=request.api_key,
            client=self.client,
        )

    @dial_exception_decorator
    async def chat_completion(self, request: Request, response: Response):
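        """Stream completion choices, emulating n > 1 via repeated calls
        and honoring max_prompt_tokens via prompt truncation."""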
        response.set_model(request.deployment_id)

        model = await self._get_model(request)
        tools = ToolsConfig.from_request(request)
        static_tools = StaticToolsConfig.from_request(request)
        prompt = await model.parse_prompt(tools, static_tools, request.messages)

        if isinstance(prompt, UserError):
            await prompt.report_usage(response)
            raise prompt

        params = ModelParameters.create(request)

        # n > 1 isn't passed through to the model: it is emulated by making
        # n concurrent single-choice calls (see asyncio.gather below).
        n = params.n or 1
        params.n = None

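        # Truncate the prompt only when the client asked for it via
        # max_prompt_tokens; otherwise keep every message.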
        if params.max_prompt_tokens is None:
            truncated_prompt = TruncatedPrompt(
                prompt=prompt, discarded_messages=[]
            )
        else:
            if not is_implemented(model.truncate_prompt):
                raise ValidationError(
                    "max_prompt_tokens request parameter is not supported"
                )
            truncated_prompt = await model.truncate_prompt(
                prompt, params.max_prompt_tokens
            )

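        # Each of the n generations streams into its own choice consumer;
        # their token usage is accumulated into a shared counter.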
        async def generate_response(usage: TokenUsage) -> None:
            with ChoiceConsumer(response=response) as consumer:
                await model.chat(params, consumer, truncated_prompt.prompt)
                usage.accumulate(consumer.usage)

            log.debug(
                f"finish_reason[{consumer.choice_idx}]: {consumer.finish_reason}"
            )

        usage = TokenUsage()

        await asyncio.gather(*(generate_response(usage) for _ in range(n)))

        log.debug(f"usage: {usage}")
        response.set_usage(usage.prompt_tokens, usage.completion_tokens)

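        # Discarded messages are reported back only when truncation was requested.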
        if params.max_prompt_tokens is not None:
            response.set_discarded_messages(truncated_prompt.discarded_messages)

    @override
    @dial_exception_decorator
    async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
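        """Count tokens for every input: a raw string or a whole chat request."""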
        model = await self._get_model(request)

        if not is_implemented(
            model.count_completion_tokens
        ) or not is_implemented(model.count_prompt_tokens):
            raise ResourceNotFoundError("The endpoint is not implemented")

        outputs: List[TokenizeOutput] = []
        for input in request.inputs:
            match input:
                case TokenizeInputRequest():
                    outputs.append(
                        await self._tokenize_request(model, input.value)
                    )
                case TokenizeInputString():
                    outputs.append(
                        await self._tokenize_string(model, input.value)
                    )
                case _:
                    assert_never(input)
        return TokenizeResponse(outputs=outputs)

    async def _tokenize_string(
        self, model: ChatCompletionAdapter, value: str
    ) -> TokenizeOutput:
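        """Count tokens in a plain string, reporting failures per input."""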
        try:
            tokens = await model.count_completion_tokens(value)
            return TokenizeSuccess(token_count=tokens)
        except Exception as e:
            return TokenizeError(error=str(e))

    async def _tokenize_request(
        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
    ) -> TokenizeOutput:
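        """Count prompt tokens of a single chat completion request."""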
        try:
            tools = ToolsConfig.from_request(request)
            static_tools = StaticToolsConfig.from_request(request)
            prompt = await model.parse_prompt(
                tools, static_tools, request.messages
            )
            if isinstance(prompt, UserError):
                raise prompt

            token_count = await model.count_prompt_tokens(prompt)
            return TokenizeSuccess(token_count=token_count)
        except Exception as e:
            return TokenizeError(error=str(e))

    @override
    @dial_exception_decorator
    async def truncate_prompt(
        self, request: TruncatePromptRequest
    ) -> TruncatePromptResponse:
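        """Compute which messages each input request would discard to fit
        its max_prompt_tokens budget."""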
        model = await self._get_model(request)

        if not is_implemented(model.truncate_prompt):
            raise ResourceNotFoundError("The endpoint is not implemented")

        outputs: List[TruncatePromptResult] = []
        for input in request.inputs:
            outputs.append(await self._truncate_prompt_request(model, input))
        return TruncatePromptResponse(outputs=outputs)

    async def _truncate_prompt_request(
        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
    ) -> TruncatePromptResult:
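        """Truncate a single request's prompt, reporting failures per input."""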
        try:
            if request.max_prompt_tokens is None:
                raise ValidationError("max_prompt_tokens is required")

            tools = ToolsConfig.from_request(request)
            static_tools = StaticToolsConfig.from_request(request)
            prompt = await model.parse_prompt(
                tools, static_tools, request.messages
            )

            truncated_prompt = await model.truncate_prompt(
                prompt, request.max_prompt_tokens
            )
            return TruncatePromptSuccess(
                discarded_messages=truncated_prompt.discarded_messages
            )
        except Exception as e:
            return TruncatePromptError(error=str(e))
