import asyncio
from typing import List, Optional, assert_never

from aidial_sdk.chat_completion import ChatCompletion, Request, Response
from aidial_sdk.chat_completion.request import ChatCompletionRequest
from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.deployment.tokenize import (
    TokenizeError,
    TokenizeInputRequest,
    TokenizeInputString,
    TokenizeOutput,
    TokenizeRequest,
    TokenizeResponse,
    TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
    TruncatePromptError,
    TruncatePromptRequest,
    TruncatePromptResponse,
    TruncatePromptResult,
    TruncatePromptSuccess,
)
from typing_extensions import override

from aidial_adapter_bedrock.adapter_deployments import (
    AdapterChatCompletionDeployment,
)
from aidial_adapter_bedrock.aws_client_config import AWSClientConfigFactory
from aidial_adapter_bedrock.dial_api.request import ModelParameters
from aidial_adapter_bedrock.dial_api.token_usage import TokenUsage
from aidial_adapter_bedrock.llm.chat_model import (
    ChatCompletionAdapter,
    TextCompletionAdapter,
)
from aidial_adapter_bedrock.llm.consumer import ChoiceConsumer
from aidial_adapter_bedrock.llm.errors import UserError, ValidationError
from aidial_adapter_bedrock.llm.model.adapter import get_bedrock_adapter
from aidial_adapter_bedrock.llm.truncate_prompt import DiscardedMessages
from aidial_adapter_bedrock.server.exceptions import (
    dial_exception_decorator,
    not_implemented_handler,
)
from aidial_adapter_bedrock.utils.log_config import app_logger as log


class BedrockChatCompletion(ChatCompletion):
    deployment: AdapterChatCompletionDeployment

    def __init__(self, deployment: AdapterChatCompletionDeployment) -> None:
        self.deployment = deployment

    async def _get_model(
        self, request: FromRequestDeploymentMixin
    ) -> ChatCompletionAdapter:
        aws_client_config = await AWSClientConfigFactory(
            request=request,
        ).get_client_config()

        return await get_bedrock_adapter(
            deployment=self.deployment,
            api_key=request.api_key,
            aws_client_config=aws_client_config,
        )

    @dial_exception_decorator
    async def chat_completion(self, request: Request, response: Response):
        response.set_model(self.deployment.upstream_deployment_id)

        model = await self._get_model(request)
        params = ModelParameters.create(request)

        discarded_messages: Optional[DiscardedMessages] = None

        async def generate_response(usage: TokenUsage) -> None:
            nonlocal discarded_messages

            with ChoiceConsumer(response=response) as consumer:
                if isinstance(model, TextCompletionAdapter):
                    consumer.set_tools_emulator(
                        model.tools_emulator(params.tool_config)
                    )

                try:
                    await model.chat(consumer, params, request.messages)
                except UserError as e:
                    await e.report_usage(consumer.choice)
                    await response.aflush()
                    raise e

                usage.accumulate(consumer.usage)
                discarded_messages = consumer.discarded_messages

        usage = TokenUsage()

        await asyncio.gather(
            *(generate_response(usage) for _ in range(request.n or 1))
        )

        log.debug(f"usage: {usage}")
        response.set_usage(usage.prompt_tokens, usage.completion_tokens)

        if discarded_messages is not None:
            response.set_discarded_messages(discarded_messages)

    @override
    @dial_exception_decorator
    @not_implemented_handler
    async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
        model = await self._get_model(request)

        outputs: List[TokenizeOutput] = []
        for input in request.inputs:
            match input:
                case TokenizeInputRequest():
                    outputs.append(
                        await self._tokenize_request(model, input.value)
                    )
                case TokenizeInputString():
                    outputs.append(
                        await self._tokenize_string(model, input.value)
                    )
                case _:
                    assert_never(input.type)
        return TokenizeResponse(outputs=outputs)

    async def _tokenize_string(
        self, model: ChatCompletionAdapter, value: str
    ) -> TokenizeOutput:
        try:
            tokens = await model.count_completion_tokens(value)
            return TokenizeSuccess(token_count=tokens)
        except NotImplementedError:
            raise
        except Exception as e:
            return TokenizeError(error=str(e))

    async def _tokenize_request(
        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
    ) -> TokenizeOutput:
        params = ModelParameters.create(request)

        try:
            token_count = await model.count_prompt_tokens(
                params, request.messages
            )
            return TokenizeSuccess(token_count=token_count)
        except NotImplementedError:
            raise
        except Exception as e:
            return TokenizeError(error=str(e))

    @override
    @dial_exception_decorator
    @not_implemented_handler
    async def truncate_prompt(
        self, request: TruncatePromptRequest
    ) -> TruncatePromptResponse:
        model = await self._get_model(request)

        outputs: List[TruncatePromptResult] = []
        for input in request.inputs:
            outputs.append(await self._truncate_prompt_request(model, input))
        return TruncatePromptResponse(outputs=outputs)

    async def _truncate_prompt_request(
        self, model: ChatCompletionAdapter, request: ChatCompletionRequest
    ) -> TruncatePromptResult:
        try:
            params = ModelParameters.create(request)

            if params.max_prompt_tokens is None:
                raise ValidationError("max_prompt_tokens is required")

            discarded_messages = await model.compute_discarded_messages(
                params, request.messages
            )

            return TruncatePromptSuccess(
                discarded_messages=discarded_messages or []
            )
        except NotImplementedError:
            raise
        except Exception as e:
            return TruncatePromptError(error=str(e))
