from sentence_transformers import SentenceTransformer
import transformers
import logging
from typing import List
from typing import Dict

from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.types import ProcessSingleEmbeddingCommand
from dataiku.huggingface.types import ProcessSingleEmbeddingResponse
from dataiku.huggingface.env_collector import extract_model_architecture

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

class ModelPipelineTextEmbeddingExtraction(ModelPipelineBatching[ProcessSingleEmbeddingCommand, ProcessSingleEmbeddingResponse]):
    """Batched text-embedding pipeline backed by a sentence-transformers model.

    Each request carries a text (``request["query"]["text"]``) and a
    ``textOverflowMode`` setting. Texts are embedded in batches with
    ``SentenceTransformer.encode`` and each response reports the embedding
    together with the number of tokens of the (untruncated) input text.
    """

    def __init__(self, model_path, batch_size) -> None:
        """Load the embedding model and record tracking metadata.

        Args:
            model_path: Location of the model, passed as-is to
                ``SentenceTransformer`` and ``transformers.AutoConfig``.
            batch_size: Forwarded to the batching base class.
        """
        super().__init__(batch_size=batch_size)
        self._model = SentenceTransformer(model_path)

        # NOTE(review): `model_tracking_data` is not created here; presumably
        # it is initialized by the base class constructor -- confirm.
        try:
            # May only work for transformers-compatible models
            model_config = transformers.AutoConfig.from_pretrained(model_path)
            self.model_tracking_data["model_architecture"] = extract_model_architecture(model_config)
        except Exception:
            # Best effort: a missing architecture must not prevent the model
            # from being used. `Exception` (rather than a bare `except`) lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.exception("Failed to load model config for %s", model_path)
            self.model_tracking_data["model_architecture"] = "unknown"

        self.model_tracking_data["task"] = "text-embedding"
        self.model_tracking_data["used_engine"] = "sentence-transformers"

    def _get_inputs(self, requests: List[ProcessSingleEmbeddingCommand]) -> List[str]:
        """Return the text to embed for each request, in request order."""
        return [request["query"]["text"] for request in requests]

    def _get_params(self, request: ProcessSingleEmbeddingCommand) -> Dict:
        """Extract the inference parameters (text overflow mode) of a request."""
        return {"text_overflow_mode": request["settings"]["textOverflowMode"]}

    def _run_inference(self, text_batch: List[str], params: Dict) -> List[ProcessSingleEmbeddingResponse]:
        """Embed a batch of texts.

        Args:
            text_batch: Texts to embed.
            params: Must contain ``"text_overflow_mode"``. Any value other
                than ``"TRUNCATE"`` makes over-long texts an error.

        Returns:
            One ``{"embedding": ..., "promptTokens": ...}`` dict per input
            text, in input order. ``promptTokens`` is the token count of the
            text tokenized without truncation.

        Raises:
            Exception: If the overflow mode is not ``"TRUNCATE"`` and a text
                has more tokens than the model's ``max_seq_length``.
        """
        if not text_batch:
            # Nothing to embed. Also avoids max() on an empty sequence and
            # calling encode() with batch_size=0 below.
            return []

        tokenized_batch = self._model.tokenizer(text_batch, truncation=False)
        # For more info about the tokenizer response type see https://huggingface.co/docs/transformers/preprocessing#natural-language-processing
        # Read the token ids through the "input_ids" key: integer indexing of
        # the tokenizer output is only available with fast tokenizers.
        token_ids_per_text = tokenized_batch["input_ids"]

        if params["text_overflow_mode"] != "TRUNCATE":
            # sentence-transformer truncates by default the texts when they are longer than what the model supports.
            # if a different behavior is needed (failure mode) we have to detect text overflow prior to use the model:
            token_count = max(len(token_ids) for token_ids in token_ids_per_text)
            max_seq_length = self._model.max_seq_length
            # NOTE(review): assumes a model may report no limit (None); in that
            # case there is nothing to compare against, and `int > None` would
            # raise TypeError -- confirm against sentence-transformers.
            if max_seq_length is not None and token_count > max_seq_length:
                raise Exception(f"Found a text longer ({token_count} tokens) than what the model supports ({max_seq_length} tokens limit). Try activating chunking, reducing the chunk size or select a bigger model.")

        # The whole batch is encoded in a single forward pass.
        embeddings = self._model.encode(text_batch, batch_size=len(text_batch)).tolist()
        return [
            {"embedding": embedding, "promptTokens": len(token_ids)}
            for embedding, token_ids in zip(embeddings, token_ids_per_text)
        ]

    def _parse_response(self, response: ProcessSingleEmbeddingResponse, request: ProcessSingleEmbeddingCommand) -> ProcessSingleEmbeddingResponse:
        """Return the inference response unchanged; no post-processing is needed."""
        return response
