from typing import Any, List, Optional

from common.backend.models.base import MediaSummary, RetrieverMode, UploadChainTypes


class APIResponseProcessor:
    """Read individual fields out of a chat API payload.

    The payload is validated once at construction time; every ``extract_*``
    method is a read-only accessor that falls back to a default when its key
    is absent.
    """

    def __init__(self, api_response: dict):
        """Store the payload and validate it.

        Args:
            api_response: The decoded JSON payload.

        Raises:
            ValueError: If the payload is not a dict or lacks a required key.
        """
        self.api_response = api_response
        self.validate_response()

    def validate_response(self):
        """Validates the type of the payload and the presence of required keys.

        Raises:
            ValueError: If ``api_response`` is not a dict, or for the first of
                ``query`` / ``conversation_id`` that is missing.
        """
        if not isinstance(self.api_response, dict):
            raise ValueError("API response must be a dictionary.")

        # Check for required keys (reported in this order).
        required_keys = ["query", "conversation_id"]
        for key in required_keys:
            if key not in self.api_response:
                raise ValueError(f"Missing key in API response: {key}")

    def extract_query(self) -> Any:
        """Extracts the main query from the API response ("" if absent)."""
        return self.api_response.get("query", "")

    def extract_query_index(self) -> Any:
        """Extracts the discussion query index from the API response ("" if absent)."""
        return self.api_response.get("query_index", "")

    def extract_conversation_id(self) -> Any:
        """Extract the current conversation id (None if absent)."""
        return self.api_response.get("conversation_id", None)

    def extract_filters(self) -> Any:
        """Return the ``filters`` field, or None if absent."""
        return self.api_response.get("filters", None)

    def extract_chain_type(self) -> Optional[str]:
        """Derive a single upload chain type from the attached media summaries.

        Returns:
            ``None`` when there are no media summaries. Otherwise:

            - ``UploadChainTypes.IMAGE.value`` if at least one summary declares
              a ``chain_type`` and every declared one is IMAGE;
            - ``UploadChainTypes.DOCUMENT_AS_IMAGE.value`` if some summary is
              DOCUMENT_AS_IMAGE and none is SHORT_DOCUMENT;
            - ``UploadChainTypes.SHORT_DOCUMENT.value`` in every other case.
        """
        # TODO maybe not even needed here
        media_summaries = self.extract_media_summaries()
        if not media_summaries:
            return None

        # NOTE(review): assumes MediaSummary instances are dict-like (e.g. a
        # TypedDict) since `.get` is called on them — confirm.
        declared = [item.get("chain_type") for item in media_summaries if item.get("chain_type")]

        # `declared` must be non-empty: all() over an empty list is True, which
        # would wrongly classify summaries with no chain type as IMAGE.
        if declared and all(ct == UploadChainTypes.IMAGE.value for ct in declared):
            return UploadChainTypes.IMAGE.value
        if (
            UploadChainTypes.DOCUMENT_AS_IMAGE.value in declared
            and UploadChainTypes.SHORT_DOCUMENT.value not in declared
        ):
            return UploadChainTypes.DOCUMENT_AS_IMAGE.value
        return UploadChainTypes.SHORT_DOCUMENT.value

    def extract_media_summaries(self) -> Optional[List[MediaSummary]]:
        """Build ``MediaSummary`` objects from ``media_summaries``.

        Returns:
            A list with one ``MediaSummary`` per raw entry, or None when the
            field is missing or empty.
        """
        raw_summaries = self.api_response.get("media_summaries")
        if not raw_summaries:
            return None
        return [MediaSummary(**summary) for summary in raw_summaries]

    def extract_user_profile(self) -> Any:
        """Return the ``user_profile`` field, or None if absent."""
        return self.api_response.get("user_profile", None)

    def extract_knowledge_bank_id(self) -> Optional[str]:
        """Extracts the knowledge bank id from the API response.

        Returns:
            The ``source`` (as a string) of the first ``retrieval_selection``
            entry whose ``type`` is ``RetrieverMode.KB.value`` and which has a
            source. Returns None if there is no such entry.
        """
        # TODO: this is a temp fix for the current logging dataset schema
        retrieval_selection = self.api_response.get("retrieval_selection")
        if not retrieval_selection:
            return None
        for ret in retrieval_selection:
            source = ret.get("source")
            # Skip KB entries without a source so we never return the string "None".
            if ret.get("type") == RetrieverMode.KB.value and source is not None:
                return str(source)
        return None

    def extract_retrieval_enabled(self) -> Any:
        """Extracts the ``retrieval_enabled`` flag (None if absent)."""
        return self.api_response.get("retrieval_enabled")

    def extract_answer(self) -> Any:
        """Return the ``answer`` field, or "" if absent."""
        return self.api_response.get("answer", "")

    def extract_sources(self) -> Any:
        """Return the ``sources`` field, or a fresh empty list if absent."""
        return self.api_response.get("sources", [])

    def extract_retrieval_selection(self) -> Any:
        """Return the raw ``retrieval_selection`` field, or None if absent."""
        return self.api_response.get("retrieval_selection")

    def extract_conversation_type(self) -> Any:
        """Return the ``conversation_type`` field, or None if absent."""
        return self.api_response.get("conversation_type")


class APIConvTitleRequestProcessor:
    """Wrap a conversation-title request payload and expose its fields.

    The payload must be a dict containing both ``query`` and ``answer``; this
    is checked once, when the processor is constructed.
    """

    def __init__(self, api_response: dict):
        """Keep a reference to ``api_response`` and validate it immediately."""
        self.api_response = api_response
        self.validate_response()

    def validate_response(self):
        """Ensure the payload is a dict holding every required key.

        Raises:
            ValueError: If the payload is not a dict, or naming the first of
                ``query`` / ``answer`` that is missing.
        """
        payload = self.api_response
        if not isinstance(payload, dict):
            raise ValueError("API response must be a dictionary.")
        first_missing = next((name for name in ("query", "answer") if name not in payload), None)
        if first_missing is not None:
            raise ValueError(f"Missing key in API response: {first_missing}")

    def extract_query(self) -> Any:
        """Return the ``query`` field, or an empty string when absent."""
        payload = self.api_response
        return payload.get("query", "")

    def extract_answer(self) -> Any:
        """Return the ``answer`` field, or an empty string when absent."""
        payload = self.api_response
        return payload.get("answer", "")


class APIGeneralFeedbackRequestProcessor:
    """Wrap a general-feedback request payload and expose its fields.

    The payload must be a dict containing ``message``; this is checked once,
    when the processor is constructed.
    """

    def __init__(self, api_response: dict):
        """Keep a reference to ``api_response`` and validate it immediately."""
        self.api_response = api_response
        self.validate_response()

    def validate_response(self):
        """Ensure the payload is a dict that carries a ``message`` key.

        Raises:
            ValueError: If the payload is not a dict or ``message`` is missing.
        """
        if not isinstance(self.api_response, dict):
            raise ValueError("API response must be a dictionary.")
        absent = [name for name in ("message",) if name not in self.api_response]
        if absent:
            raise ValueError(f"Missing key in API response: {absent[0]}")

    def extract_message(self) -> Any:
        """Return the feedback ``message``, or an empty string when absent."""
        payload = self.api_response
        return payload.get("message", "")

    def extract_knowledge_bank_id(self) -> Any:
        """Return the ``knowledge_bank_id`` field, or None when absent."""
        payload = self.api_response
        return payload.get("knowledge_bank_id")