import logging
from abc import abstractmethod
from math import ceil
from os import environ

import pandas as pd

import dataiku
from dataiku.external_ml.proxy_model.common.inputformat import get_valid_input_formats, BaseWriter
from dataiku.external_ml.proxy_model.common.outputformat import BaseReader, guess_input_and_output_formats, \
    guess_input_format_from_output_format

# Module-level logger, named after this module; used by the proxy model classes below.
logger = logging.getLogger(__name__)


class ProxyModel(object):
    """Base class for Proxy model

    Each implementation should be registered in dataiku.external_ml.proxy_model.__init__.py
    """

    def __init__(self):
        """Should set attributes for URL of the model, id of the model, eventual headers etc.

        :raises NotImplementedError: always in this base class; subclasses provide their own constructor
        """
        raise NotImplementedError

    def predict(self, input_df):
        """Should output a prediction for a given input dataframe by calling the right endpoint

        Should output a pd.DataFrame with at least the "prediction" column and optionally
        columns proba_label0, proba_label1 etc. for classification

        :param input_df: dataframe holding the rows to score
        :raises NotImplementedError: always in this base class
        """

        raise NotImplementedError

    @staticmethod
    def get_connection_info(connection_name, context_project_key, expected_type, expected_type_desc):
        """Fetch the info of a connection and check that it has the expected type.

        :param connection_name: name of the connection to look up through the API client
        :param context_project_key: project key passed to the connection's get_info call
        :param expected_type: value that the "type" entry of the connection info must equal
        :param expected_type_desc: human readable description of the expected type, used in the error message
        :return: the connection info returned by get_info
        :raises ValueError: if no info is returned or the connection has a different type
        :raises Exception: if the info has no "params" entry (reported as a missing permission)
        """
        connection = dataiku.api_client().get_connection(connection_name)
        connection_info = connection.get_info(context_project_key)
        if connection_info is None or connection_info["type"] != expected_type:
            raise ValueError("The connection " + connection_name + " is not a " + expected_type_desc)
        if "params" not in connection_info:
            # The space before "required" was missing, gluing the word to the connection name.
            raise Exception("Permission to view details of the connection " + connection_name + " required")

        return connection_info

    @staticmethod
    def runs_on_real_api_node():
        """Tell whether the environment variables describe a real API node.

        :return: True when DKU_CURRENT_APISERVICE is set to a non-empty value and
                 DKU_LAMBDA_DEVSERVER is unset or empty, False otherwise
        :rtype: bool
        """
        # bool() so that callers always get a boolean rather than the raw variable value or None.
        return bool(not environ.get("DKU_LAMBDA_DEVSERVER") and environ.get("DKU_CURRENT_APISERVICE"))


class ProxyModelEndpointClient(object):
    """Base class for the low-level clients used by proxy models.

    Subclasses are typically adapters around the client provided by cloud SDKs.
    Going through a common interface lets the higher-level logic (guessing
    input/output formats and batching) be shared between implementations.

    Implementing this class is required only if you want to subclass ChunkedAndFormatGuessingProxyModel
    """

    # NOTE: this class does not use ABCMeta, so @abstractmethod acts as a marker only;
    # the NotImplementedError below is what fails when the method is not overridden.
    @abstractmethod
    def call_endpoint(self, data, content_type):
        """Should call an endpoint with 'data' using the 'content_type' content-type
        and return the response.

        :raises NotImplementedError: always in this base class
        """
        raise NotImplementedError()

    @staticmethod
    def get_inference_type():
        """Return the flag identifying the kind of inference request being made.

        Non-scoring inference requests should be flagged, so they can be ignored in the evaluation recipe of
        the prediction logs. Import evaluation requests, test queries, interpretability and evaluation requests will be
        flagged as 'dku_misc', and the ones from scoring recipes as 'dku_scoring'

        The value is read from the _DKU_PROXY_MODELS_INFERENCE_TYPE environment variable, with
        'dku_misc' used when the variable is not set.
        """
        variable_name = "_DKU_PROXY_MODELS_INFERENCE_TYPE"
        if variable_name in environ:
            return environ[variable_name]
        return "dku_misc"


class BatchConfiguration(object):
    """Holds the two settings that drive request chunking.

    :param max_request_size: upper bound on the size of a single request
    :param initial_batch_size: number of rows to try sending per request at first (defaults to 50000)
    """

    def __init__(self, max_request_size, initial_batch_size=50000):
        # Plain value holder: both settings are stored unchanged.
        self.max_request_size, self.initial_batch_size = max_request_size, initial_batch_size


class ChunkedAndFormatGuessingProxyModel(ProxyModel):
    """A more elaborate ProxyModel which provides support for:
        - chunking requests to work with request size limits
        - multiple input/output formats with automatic guessing

    Subclasses must implement get_client(), which is called from the constructor.
    """

    def __init__(self, prediction_type, value_to_class, supported_input_formats, supported_output_formats, input_format, output_format, batch_config):
        """
        Store the configuration, create the endpoint client through get_client() and
        instantiate the writer/reader when their names are explicitly given.

        :param prediction_type: passed to the output formats (readers) when they are instantiated
        :param value_to_class: passed to the output formats (readers) when they are instantiated
        :param supported_input_formats: List of types that are subtypes of BaseWriter
        :type supported_input_formats: list[type]
        :param supported_output_formats: List of types that are subtypes of BaseReader
        :type supported_output_formats: list[type]
        :param input_format: name of the input format to use, or None / "GUESS" to leave it unset
        :type input_format: str
        :param output_format: name of the output format to use, or None / "GUESS" to leave it unset
        :type output_format: str
        :param batch_config: batching settings (max_request_size and initial_batch_size are read from it)
        :type batch_config: BatchConfiguration
        """
        self.prediction_type = prediction_type
        self.value_to_class = value_to_class
        self.supported_input_formats = supported_input_formats
        self.supported_output_formats = supported_output_formats
        # :type self.input_format: Optional[BaseWriter]
        self.input_format = None
        # :type self.output_format: Optional[BaseReader]
        self.output_format = None
        self.batch_config = batch_config
        # Number of completed predict() calls; used to periodically reset the batch size.
        self.current_batch = 0
        self.current_batch_size = batch_config.initial_batch_size
        # The client must exist before the formats are created: writers are built with it.
        self.client = self.get_client()
        self.set_input_output_formats(supported_input_formats, supported_output_formats, input_format, output_format)

    def set_input_output_formats(self, supported_input_formats, supported_output_formats, input_format, output_format):
        """
        Instantiate and store the writer and reader matching the given format names.
        Each attribute is set to None when the corresponding name is None or "GUESS".

        :param supported_input_formats: List of types that are subtypes of BaseWriter
        :type supported_input_formats: list[type]
        :param supported_output_formats: List of types that are subtypes of BaseReader
        :type supported_output_formats: list[type]
        :param input_format: name of the input format, None or "GUESS"
        :type input_format: str
        :param output_format: name of the output format, None or "GUESS"
        :type output_format: str
        :raises ValueError: if a name is given that matches none of the supported formats
        """
        self.input_format = self.get_input_format(supported_input_formats, input_format)
        logger.info("Using writer: {}".format(type(self.input_format)))
        self.output_format = self.get_output_format(supported_output_formats, output_format)
        logger.info("Using reader: {}".format(type(self.output_format)))

    def get_client(self):
        """
        Returns a ProxyModelEndpointClient which will be used to actually call endpoints.

        :raises NotImplementedError: always in this class; subclasses must override it
        """
        raise NotImplementedError

    def guess_formats(self, input_df, input_format_name, output_format_name):
        """
        Cycle through supported input and output formats to guess which one should be used for this endpoint.

        When a format name is forced (i.e. is not "GUESS"), the matching format is instantiated and also
        stored on the instance (self.input_format / self.output_format). Guessed formats are only returned.

        :param input_df: dataframe used to create requests
        :type input_df: pd.DataFrame
        :param input_format_name: input format name if forcing a specific input format, or GUESS
        :type input_format_name: str
        :param output_format_name: output format name if forcing a specific output format, or GUESS
        :type output_format_name: str
        :return: guessed input and output format as a tuple
        :rtype: (BaseWriter, BaseReader)
        :raises ValueError: if a forced format name is not among the supported formats
        :raises Exception: if the endpoint cannot be queried/parsed with a forced format, or if no
                           compatible format could be found
        """
        # Work around the fact that some endpoints don't like being called with only one row.
        if len(input_df) == 1:
            input_df = pd.concat([input_df, input_df], ignore_index=True)

        # No need to pass the full df since we are only detecting how we can talk to that endpoint. The 2 first rows should be enough (1 is not good,
        # some of our testing endpoints will complain when there is only 1 row)
        first_two_rows = input_df.head(2)
        input_format = None
        output_format = None

        if input_format_name == "GUESS":
            logger.info("Trying to guess the writer...")
            valid_input_formats_with_predictions = get_valid_input_formats(self.supported_input_formats, self.client, first_two_rows)
            logger.info("{} input formats are valid. We will test them against output format".format(len(valid_input_formats_with_predictions)))
        else:
            self.input_format = self.get_input_format(self.supported_input_formats, input_format_name)
            logger.info("Using writer: {}".format(type(self.input_format)))
            input_format = self.input_format
            try:
                # Query the endpoint once with the forced writer; the raw answer is kept to test the readers.
                predictions = input_format.write(first_two_rows)
                valid_input_formats_with_predictions = [(input_format, predictions)]
            except Exception as e:
                msg = "Error while testing the endpoint with manually selected input format {} and data {}: {}." \
                      " Make sure that your dataset columns aligns with what your model expects or try another input format.".format(input_format.NAME, first_two_rows, e)
                logger.error(msg)
                raise Exception(msg)

        if output_format_name == "GUESS":
            logger.info("Trying to guess the reader...")
            input_format, output_format = guess_input_and_output_formats(self.supported_output_formats, self.prediction_type, first_two_rows, valid_input_formats_with_predictions, self.value_to_class)
        else:
            self.output_format = self.get_output_format(self.supported_output_formats, output_format_name)
            logger.info("Using reader: {}".format(type(self.output_format)))
            output_format = self.output_format
            if input_format is None:
                # The writer was not forced: pick it among the valid candidates using the forced reader.
                try:
                    input_format = guess_input_format_from_output_format(output_format, first_two_rows, valid_input_formats_with_predictions)
                except Exception as e:
                    msg = "Error while testing the endpoint with manually selected output format {} and result returned by" \
                          " the endpoint when submitting {}: {}. Try another output format.".format(output_format.NAME, first_two_rows, e)
                    logger.error(msg)
                    raise Exception(msg)

        if input_format is None:
            raise Exception("Could not find any compatible writing format for this endpoint.")
        if output_format is None:
            raise Exception("Could not find any compatible reading format for this endpoint.")

        if input_format_name == "GUESS":
            logger.info("Will use auto-guessed writer: {}".format(type(input_format)))
        if output_format_name == "GUESS":
            logger.info("Will use auto-guessed reader: {}".format(type(output_format)))

        return input_format, output_format

    def predict(self, input_df):
        """Split df so that the request size of each sub-df (as computed by the input format) is smaller than
        max_request_size. Query the endpoint for each sub batch and concatenate the results.

        The batch size found to fit is remembered across calls (self.current_batch_size) and reset to
        the initial batch size every 100 calls.

        :param input_df: rows to score
        :type input_df: pd.DataFrame
        :return: the concatenated parsed results, one row per input row
        :rtype: pd.DataFrame
        :raises AssertionError: if max_request_size leaves no room once the header margin is removed
        :raises ValueError: if a single row exceeds the request size limit, or if the number of
                            predictions received differs from the number of rows sent
        :raises Exception: if the endpoint cannot be queried or its results cannot be parsed
        """

        # Work around the fact that some endpoints don't like being called with only one row.
        if len(input_df) == 1:
            single_row_mode = True
            input_df = pd.concat([input_df, input_df], ignore_index=True)
        else:
            single_row_mode = False

        logger.info("Predicting using {} on input_df with shape {}".format(self.__class__.__name__, input_df.shape))
        input_df = input_df.copy().reset_index(drop=True)
        max_request_size_margin = self.batch_config.max_request_size - 20000  # save space for headers
        assert max_request_size_margin > 0, "max_request_size too small"
        if self.current_batch % 100 == 0:
            logger.info("Resetting batch size to initial {}".format(self.batch_config.initial_batch_size))
            self.current_batch_size = self.batch_config.initial_batch_size

        logger.info("Starting batch predict with df shape {}".format(input_df.shape))
        results = []
        start = 0
        while start < len(input_df):
            end = min(start + self.current_batch_size, len(input_df))
            logger.info("Batch from index start={} to end={} batch_size={}".format(start, end, self.current_batch_size))
            request_size = self.input_format.compute_request_size(input_df.iloc[start:end])
            if request_size > max_request_size_margin:
                logger.info("Request size={} exceed max_request_size_margin={}".format(request_size,
                                                                                       max_request_size_margin))
                # Shrink the batch proportionally to how much the request overshoots the limit.
                new_batch_size = round(self.current_batch_size / ceil(request_size / max_request_size_margin))
                if new_batch_size == 0:
                    # Validate before storing: keeping a zero batch size on the instance would make a
                    # later predict() call loop without ever advancing.
                    # A real exception must be raised here; raising a plain string is itself a TypeError.
                    raise ValueError("Evaluation dataset contains rows bigger than the max {} Mb limitation"
                                     .format(self.batch_config.max_request_size / 1e6))
                self.current_batch_size = new_batch_size
                logger.info("New attempt with batch size {}".format(self.current_batch_size))
            else:
                logger.info("Requesting external API for index {} to {}".format(start, end))
                sub_input_df = input_df.iloc[start:end]
                try:
                    raw_results = self.input_format.write(sub_input_df)
                except Exception as e:
                    logger.error(e)
                    raise Exception("Could not query endpoint, please check that your configuration is correct (including input format). Please also check "
                                    "that your input dataset matches what the endpoint expects (number, name and order of columns, data types, unexpected, "
                                    "missing or NaN values, ...). "
                                    "Error was: {}".format(str(e)))
                try:
                    parsed_results = self.output_format.read(raw_results)
                except Exception as e:
                    logger.error(e)
                    raise Exception("Could not parse endpoint results, please check that your configuration is correct (including output format). "
                                    "Error was: {}".format(str(e)))
                expected_len = len(sub_input_df)
                actual_len = len(parsed_results)
                if expected_len != actual_len:
                    raise ValueError("Sent {} input row(s), received {} prediction(s). "
                                     "Please check that your output format is correct.".format(expected_len, actual_len))
                results.append(parsed_results)
                logger.info("Success, result size={}".format(len(results)))
                start = min(start + self.current_batch_size, len(input_df))
        self.current_batch = self.current_batch + 1
        # NOTE: with an empty input_df no batch is sent and pd.concat receives an empty list.
        results = pd.concat(results, ignore_index=True)
        if single_row_mode:
            # The row was duplicated above, only hand back the prediction of the original one.
            return results.head(1)
        else:
            return results

    def get_input_format(self, supported_input_formats, input_format_name):
        """
        Instantiate the supported input format whose NAME equals input_format_name.

        :param supported_input_formats: List of types that are subtypes of BaseWriter
        :type supported_input_formats: list[type]
        :param input_format_name: name of the wanted format; None and "GUESS" mean "no explicit format"
        :type input_format_name: str
        :return: a writer built with this model's client, or None when no explicit format is requested
        :rtype: BaseWriter
        :raises ValueError: if the name matches none of the supported input formats
        """
        if input_format_name is not None and input_format_name != "GUESS":
            for supported_input_format in supported_input_formats:
                if supported_input_format.NAME == input_format_name:
                    return supported_input_format(self.client)
            raise ValueError("Unsupported input format: {}".format(input_format_name))
        else:
            return None

    def get_output_format(self, supported_output_formats, output_format_name):
        """
        Instantiate the supported output format whose NAME equals output_format_name.

        :param supported_output_formats: List of types that are subtypes of BaseReader
        :type supported_output_formats: list[type]
        :param output_format_name: name of the wanted format; None and "GUESS" mean "no explicit format"
        :type output_format_name: str
        :return: a reader built with this model's prediction_type and value_to_class, or None when
                 no explicit format is requested
        :rtype: BaseReader
        :raises ValueError: if the name matches none of the supported output formats
        """
        if output_format_name is not None and output_format_name != "GUESS":
            for supported_output_format in supported_output_formats:
                if supported_output_format.NAME == output_format_name:
                    return supported_output_format(self.prediction_type, self.value_to_class)
            raise ValueError("Unsupported output format: {}".format(output_format_name))
        else:
            return None

    @staticmethod
    def get_proxy():
        """Return the value of the _PROXY_MODEL_PROXY environment variable, or None when it is not set."""
        return environ.get("_PROXY_MODEL_PROXY", None)
