import json
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar, assert_never

from aidial_sdk.chat_completion.request import Role
from google.genai.types import Content as GenAIContent
from google.genai.types import Part as GenAIPart
from pydantic.v1 import BaseModel
from vertexai.preview.generative_models import ChatSession, Content, Part

from aidial_adapter_vertexai.chat.errors import ValidationError

# Type parameters shared by the SDK-agnostic factory/conversation bases below.
# SDK-specific "part" type (e.g. a text or inline-data chunk).
PartT = TypeVar("PartT")
# SDK-specific "content" type (a role plus a list of parts).
ContentT = TypeVar("ContentT")
# Concrete conversation model produced by a factory.
GeminiConversationT = TypeVar("GeminiConversationT")


class ConversationFactoryBase(
    ABC, Generic[PartT, ContentT, GeminiConversationT]
):
    """Abstract builder of SDK-specific Gemini request objects.

    Subclasses bind the three type parameters to a concrete SDK:

    * ``PartT`` - a single part (text, inline data, function call/result);
    * ``ContentT`` - a role together with a list of parts;
    * ``GeminiConversationT`` - the assembled conversation.
    """

    @abstractmethod
    def create_multi_modal_part(self, data: bytes, mime_type: str) -> PartT:
        """Return a part carrying raw ``data`` tagged with ``mime_type``."""

    @abstractmethod
    def create_text_part(self, text: str) -> PartT:
        """Return a part holding the plain ``text``."""

    @abstractmethod
    def create_function_call_part(self, name: str, args: str) -> PartT:
        """Return a part describing a call to the function ``name``.

        ``args`` is the JSON-encoded argument string of the call.
        """

    @abstractmethod
    def create_function_result_part(self, name: str, args: str) -> PartT:
        """Return a part carrying the result of the function ``name``.

        Despite its name, ``args`` holds the function's result string.
        """

    @abstractmethod
    def create_content(self, role: Role, parts: List[PartT]) -> ContentT:
        """Group ``parts`` into one content entry attributed to ``role``."""

    @abstractmethod
    def create_conversation(
        self,
        system_instruction: List[PartT] | None,
        contents: List[ContentT],
    ) -> GeminiConversationT:
        """Assemble a conversation from an optional system instruction
        and the list of content entries."""


class GeminiConversationBase(BaseModel, Generic[PartT, ContentT]):
    """A Gemini conversation: an optional system instruction plus contents.

    Generic over the SDK-specific part and content types, so the same shape
    is reused by the ``vertexai`` and ``google.genai`` conversation models.
    """

    # Parts making up the system instruction; None when there is none.
    system_instruction: List[PartT] | None = None
    # Conversation entries (role + parts).
    contents: List[ContentT]

    class Config:
        # Permit non-pydantic field types (e.g. SDK Part/Content classes).
        arbitrary_types_allowed = True


class GeminiConversation(GeminiConversationBase[Part, Content]):
    """Gemini conversation built from ``vertexai`` ``Part``/``Content``."""

    pass


class ConversationFactory(
    ConversationFactoryBase[Part, Content, GeminiConversation]
):
    @staticmethod
    def _to_gemini_role(role: Role) -> str:
        match role:
            case Role.SYSTEM:
                raise ValidationError(
                    "System messages other than the first system message are not allowed"
                )
            case Role.USER | Role.FUNCTION | Role.TOOL:
                return ChatSession._USER_ROLE
            case Role.ASSISTANT:
                return ChatSession._MODEL_ROLE
            case _:
                assert_never(role)

    def create_multi_modal_part(self, data: bytes, mime_type: str) -> Part:
        return Part.from_data(data=data, mime_type=mime_type)

    def create_text_part(self, text: str) -> Part:
        return Part.from_text(text)

    def create_function_call_part(self, name: str, args: str) -> Part:
        try:
            args = json.loads(args)
            return Part.from_dict(
                {"function_call": {"name": name, "args": args}}
            )
        except Exception:
            raise ValidationError(
                "Function call arguments must be a valid JSON"
            )

    def create_function_result_part(self, name: str, args: str) -> Part:
        try:
            args = json.loads(args)
        except Exception:
            args = args

        if isinstance(args, dict):
            return Part.from_function_response(name, args)

        return Part.from_function_response(name, {"content": args})

    def create_content(self, role: Role, parts: List[Part]) -> Content:
        return Content(role=self._to_gemini_role(role), parts=parts)

    def create_conversation(
        self, system_instruction: List[Part] | None, contents: List[Content]
    ) -> GeminiConversation:
        return GeminiConversation(
            system_instruction=system_instruction, contents=contents
        )


class GeminiGenAIConversation(GeminiConversationBase[GenAIPart, GenAIContent]):
    """Gemini conversation built from ``google.genai`` ``Part``/``Content``."""

    pass


class GenAIConversationFactory(
    ConversationFactoryBase[GenAIPart, GenAIContent, GeminiGenAIConversation]
):
    @staticmethod
    def to_gemini_genai_role(role: Role) -> str:
        match role:
            case Role.SYSTEM:
                raise ValidationError(
                    "System messages other than the first system message are not allowed"
                )
            case Role.USER | Role.FUNCTION | Role.TOOL:
                return "user"
            case Role.ASSISTANT:
                return "model"
            case _:
                assert_never(role)

    def create_multi_modal_part(self, data: bytes, mime_type: str) -> GenAIPart:
        return GenAIPart.from_bytes(data=data, mime_type=mime_type)

    def create_text_part(self, text: str) -> GenAIPart:
        return GenAIPart.from_text(text)

    def create_function_call_part(self, name: str, args: str) -> GenAIPart:
        try:
            return GenAIPart.from_function_call(name, json.loads(args))
        except Exception:
            raise ValidationError(
                "Function call arguments must be a valid JSON"
            )

    def create_function_result_part(self, name: str, args: str) -> GenAIPart:
        try:
            processed_args = json.loads(args)
        except Exception:
            processed_args = args

        if isinstance(processed_args, dict):
            return GenAIPart.from_function_response(name, processed_args)

        return GenAIPart.from_function_response(
            name, {"output": processed_args}
        )

    def create_content(
        self, role: Role, parts: List[GenAIPart]
    ) -> GenAIContent:
        return GenAIContent(role=self.to_gemini_genai_role(role), parts=parts)

    def create_conversation(
        self,
        system_instruction: List[GenAIPart] | None,
        contents: List[GenAIContent],
    ) -> GeminiGenAIConversation:
        return GeminiGenAIConversation(
            system_instruction=system_instruction, contents=contents
        )
