import dataiku
from dataiku.llm.types import RAGRerankingSettings
from dataikuapi.dss.llm import DSSLLMRerankingResponse
from dataikuapi.dss.llm_tracing import SpanBuilder

from enum import Enum
from typing import List, TYPE_CHECKING, Dict, Optional
from logging import Logger

if TYPE_CHECKING:
    from langchain_core.documents import Document


class ScoreType(str, Enum):
    """Names under which a ``DocumentWithScores`` stores its scores.

    Members subclass ``str``, so each compares equal to its plain string value.
    """

    # Score attached to a document at construction time (retrieval step).
    RETRIEVAL = "score"
    # Relevance score assigned by the reranking LLM.
    RERANKING = "reranking"


class DocumentWithScores:
    """Wrapper that attaches named scores to a document.

    ``metadata`` and ``page_content`` are forwarded to the wrapped document.
    Scores are stored per ``ScoreType`` key.
    """

    def __init__(self, document: "Document", score: Optional[float] = None):
        """
        :param document: the wrapped document.
        :param score: optional initial score, stored under ``ScoreType.RETRIEVAL``
            when it is not None.
        """
        self.document = document
        self._scores: Dict[ScoreType, float] = {}
        if score is None:
            return
        self.set_score(ScoreType.RETRIEVAL, score)

    @property
    def metadata(self):
        """The wrapped document's ``metadata`` attribute."""
        wrapped = self.document
        return wrapped.metadata

    @property
    def page_content(self):
        """The wrapped document's ``page_content`` attribute."""
        wrapped = self.document
        return wrapped.page_content

    def set_score(self, name: ScoreType, score: float):
        """Store ``score`` under ``name``, replacing any previous value for it."""
        self._scores.update({name: score})

    def get_score(self, name: ScoreType) -> Optional[float]:
        """Return the score stored under ``name``, or None if it was never set."""
        return self._scores.get(name, None)


class Reranking:
    """Rerank documents with a Dataiku reranking LLM and keep the top results.

    The LLM is looked up by ID in the default project of the current Dataiku
    API client. After reranking, only the first ``max_documents`` documents
    (in the order returned by the reranker) are kept.
    """

    def __init__(self, settings: RAGRerankingSettings):
        """
        :param settings: reranking settings. ``llmId`` is required;
            ``maxDocuments`` defaults to 5 when the key is absent.
        :raises ValueError: if ``llmId`` is missing or None.
        """
        self.llm_id = settings.get("llmId")
        if self.llm_id is None:
            raise ValueError("No reranking LLM ID provided")
        self.max_documents = settings.get("maxDocuments", 5)

    def _rerank(self, query, docs) -> DSSLLMRerankingResponse:
        """Send ``query`` and the ``page_content`` of each of ``docs`` to the reranking LLM.

        Documents are added in the iteration order of ``docs``. The caller
        treats the indices in the response as positions in that same sequence.
        """
        llm = dataiku.api_client().get_default_project().get_llm(self.llm_id)
        reranking_query = llm.new_reranking()
        reranking_query.with_query(query)
        for doc in docs:
            reranking_query.with_document(doc.page_content)
        return reranking_query.execute()

    def _reorder_and_filter_top_documents(self, documents: List[DocumentWithScores], ranked_documents: List[DSSLLMRerankingResponse.RankedDocument]) -> List[DocumentWithScores]:
        """Map the reranker's output back onto ``documents`` and keep the top entries.

        Each ranked entry's ``index`` selects a document from ``documents``.
        That document gets its ``RERANKING`` score set in place, including
        documents that fall past the ``max_documents`` cut-off. It is then
        appended in the order the reranker returned the entries.

        :returns: at most ``self.max_documents`` documents.
        :raises ValueError: if a ranked entry's index is outside ``documents``.
        """
        document_count = len(documents)
        reranked_documents = []
        for ranked_doc in ranked_documents:
            # Validate every index, not just the kept ones, so a malformed
            # response is never silently accepted.
            if not 0 <= ranked_doc.index < document_count:
                raise ValueError("Index out of range (%s) for reranking response" % ranked_doc.index)

            reranked_document = documents[ranked_doc.index]
            reranked_document.set_score(ScoreType.RERANKING, ranked_doc.relevance_score)
            reranked_documents.append(reranked_document)
        return reranked_documents[:self.max_documents]

    def rerank_with_scores(self, query: str, documents: List[DocumentWithScores], trace: SpanBuilder, logger: Logger) -> List[DocumentWithScores]:
        """Rerank ``documents`` against ``query`` and return the top ones with scores set.

        An empty input is returned as-is without querying the LLM. The
        reranking call runs inside a ``RERANKING_DOCUMENTS`` subspan of ``trace``.

        :raises ValueError: if the response references an out-of-range document index.
        """
        if not documents:
            logger.info("No documents to rerank, skipping reranking query")
            return documents
        with trace.subspan("RERANKING_DOCUMENTS") as reranking_span:
            response = self._rerank(query, documents)
            # Attach the LLM trace before reading response.documents, so the
            # trace is recorded even when that access raises.
            if response.trace is not None:
                reranking_span.append_trace(response.trace)
            ranked_documents = response.documents  # raises an error if the request failed
            # Pass args lazily so logging formats the message only if it is emitted.
            logger.info("Reranked %s documents", len(ranked_documents))
            return self._reorder_and_filter_top_documents(documents, ranked_documents)

    def rerank(self, query: str, documents: List["Document"], trace: SpanBuilder, logger: Logger) -> List["Document"]:
        """Rerank plain documents and return the top ones, without their scores.

        This is a convenience wrapper around ``rerank_with_scores``.
        """
        # Use a separate name so ``documents`` keeps its declared type.
        wrapped_documents = [DocumentWithScores(document) for document in documents]
        reranked = self.rerank_with_scores(query, wrapped_documents, trace, logger)
        return [doc.document for doc in reranked]
