from dataiku.llm.guardrails import BaseGuardrail
import dataiku
import logging
from sentence_splitter import split_text_into_sentences
from transformers import pipeline

# Leftover manual smoke test, kept for reference (``classifier`` is not defined in this file):
#   result = classifier("Tall people are so clumsy")

class BiasDetectorGuardrail(BaseGuardrail):
    """Guardrail that looks for biased or stereotyped sentences in a completion response.

    The response text is split into sentences, each sentence is scored by a
    Hugging Face text-classification pipeline, and sentences scoring above the
    model-specific threshold are considered dubious. Optionally
    (``config["doubleCheck"]``) a second LLM is asked to confirm the finding.
    When bias is confirmed, ``config["action"]`` decides what happens:
    ``REJECT`` raises, ``AUDIT`` attaches audit data, ``REWRITE`` asks for a retry.

    Config keys read by this class: ``model``, ``doubleCheck``, ``llm``,
    ``action``. ``threshold`` is written by :meth:`set_config`.
    """

    # System prompt sent to the confirmation LLM. Kept byte-for-byte, since any
    # change in wording may change the LLM's answers.
    _DOUBLE_CHECK_PROMPT = """
Your task is to check whether some sentences contain some biased or stereotypes.
It is critical that you distinguish actual bias from simple facts. For example, saying that a famous person had a certain
gender or nationality is a fact, not a bias, an you should not flag it.
Please answer with a simple word, either "biased" or "safe", in lowercase. No decorations, no quotes. Just a word, in lowercase.
"""

    def set_config(self, config, plugin_config):
        """Store the configuration and load the classification pipeline.

        Args:
            config: Guardrail settings dict. ``config["model"]`` selects the
                classifier. Note that ``config["threshold"]`` is overwritten
                here with a per-model hard-coded value.
            plugin_config: Plugin-level settings (unused).
        """
        self.config = config

        if self.config["model"] == "MAXIMUS":
            # This model returns a score for every bias type; only a subset is enforced.
            self.pipe = pipeline(
                "text-classification",
                model="maximuspowers/bias-type-classifier",
                return_all_scores=True,
            )
            # Disabled: nationality, socioeconomic, educational, political, physical, age
            self.enabled_labels = {"racial", "religious", "gender", "sexuality", "disability"}
            self.config["threshold"] = 0.9
        else:
            self.pipe = pipeline(
                "text-classification",
                model="d4data/bias-detection-model",
                tokenizer="d4data/bias-detection-model",
            )
            self.config["threshold"] = 0.95

    def check_text(self, input, text, trace):
        """Check ``text`` for bias and apply the configured action if bias is confirmed.

        Args:
            input: Guardrail input dict; ``responseGuardrailResponse`` is set on
                it for the ``AUDIT`` and ``REWRITE`` actions.
            text: Text to analyse. ``None`` or empty text is accepted and
                results in no check.
            trace: Trace object; the flagged sentences are stored in
                ``trace.attributes["biasedSentences"]``.

        Raises:
            Exception: if the action is ``REJECT`` and bias is confirmed, or if
                the confirmation LLM answers with something other than
                "safe"/"biased".
        """
        if not text:
            logging.info("No text to check for bias")
            return

        logging.info("Checking text for bias: %s", text)
        sentences = split_text_into_sentences(text, language='en')
        logging.info("Split into %d sentences: %s", len(sentences), sentences)

        # Model sometimes thinks an empty string is biased
        sentences = [s for s in sentences if s.strip()]
        if not sentences:
            # Nothing left to classify; avoid calling the pipeline on an empty batch.
            return

        dubious_sentences = self._find_dubious_sentences(sentences)
        if not dubious_sentences:
            return

        if self.config["doubleCheck"] and not self._llm_confirms_bias(dubious_sentences):
            return

        trace.attributes["biasedSentences"] = dubious_sentences
        self._apply_action(input, dubious_sentences)

    def _find_dubious_sentences(self, sentences):
        """Return one ``{"sentence", "label", "score"}`` dict per sentence flagged by the classifier."""
        threshold = self.config["threshold"]
        multi_label = self.config["model"] == "MAXIMUS"
        predictions = self.pipe(sentences)

        dubious_sentences = []
        for i, (sentence, prediction) in enumerate(zip(sentences, predictions)):
            if multi_label:
                logging.info("Prediction: %s", prediction)
                # Report at most one label per sentence: the first enabled label,
                # in the order returned by the pipeline, that exceeds the threshold.
                hit = next(
                    (
                        candidate for candidate in prediction
                        if candidate["label"] in self.enabled_labels
                        and candidate["score"] > threshold
                    ),
                    None,
                )
            elif prediction["score"] > threshold and prediction["label"] == "Biased":
                hit = prediction
            else:
                hit = None

            if hit is None:
                continue

            # Keep scanning after a hit so that every biased sentence is
            # reported, whichever model is configured.
            label = hit["label"]
            score = hit["score"]
            dubious_sentences.append({
                "sentence": sentence,
                "label": label,
                "score": score
            })
            logging.info("Biased sentence detected: i=%s label=%s score=%s sentence=%s", i, label, score, sentence)

        return dubious_sentences

    def _llm_confirms_bias(self, dubious_sentences):
        """Ask the LLM named by ``config["llm"]`` whether the flagged sentences are really biased."""
        logging.info("Asking a LLM to check %s controversial sentences", len(dubious_sentences))
        dubious_text = "\n".join(p["sentence"] for p in dubious_sentences)

        llm = dataiku.api_client().get_default_project().get_llm(self.config["llm"])
        resp = (
            llm.new_completion()
            .with_message(self._DOUBLE_CHECK_PROMPT, "system")
            .with_message(dubious_text)
            .execute()
        )

        # LLMs do not always follow formatting instructions to the letter:
        # tolerate surrounding whitespace, quotes, a trailing period and casing.
        verdict = (resp.text or "").strip().strip("\"'.").strip().lower()

        if verdict == "safe":
            logging.info("LLM did not confirm bias")
            return False
        if verdict == "biased":
            logging.info("LLM did confirm bias: %s", resp.text)
            return True
        raise Exception("Unexpected response from confirmation LLM: %s" % resp.text)

    def _apply_action(self, input, dubious_sentences):
        """Apply ``config["action"]`` (REJECT, AUDIT or REWRITE) for confirmed bias."""
        action = self.config["action"]

        if action == "REJECT":
            raise Exception("Biased or stereotyped sentences found in the text")

        audit_data = [{
            "origin": "Bias Detection",
            "biasedSentences": dubious_sentences
        }]

        if action == "AUDIT":
            input["responseGuardrailResponse"] = {
                "action": "PASS_WITH_AUDIT",
                "auditData": audit_data
            }
        elif action == "REWRITE":
            # Replay the original conversation followed by a corrective instruction.
            msgs = list(input["completionQuery"]["messages"])
            msgs.append({"role": "system", "content": "Your previous answer was not acceptable as it contained biased or stereotyped statements. Please rewrite"})
            input["responseGuardrailResponse"] = {
                "action": "RETRY",
                "updatedMessagesForRetry": msgs,
                "auditData": audit_data
            }
        else:
            # The response is left untouched, but make the misconfiguration visible.
            logging.warning("Unknown bias guardrail action %s, response left unchanged", action)

    def process(self, input, trace):
        """Guardrail entry point: check the completion response, if any, for bias.

        Args:
            input: Guardrail input dict. Only inputs containing a
                ``completionResponse`` entry are checked.
            trace: Trace object used to open a sub-span around the check.

        Returns:
            The same ``input`` dict, possibly with ``responseGuardrailResponse`` set.
        """
        if "completionResponse" in input:
            # The response may carry no text; check_text handles None/empty.
            text_to_check = input["completionResponse"].get("text")

            with trace.subspan("Checking text for bias"):
                self.check_text(input, text_to_check, trace)

        return input
