import logging
import random
from pprint import pformat

import numpy as np
from sklearn.preprocessing import LabelEncoder
from skorch import NeuralNetClassifier
from skorch import NeuralNetRegressor
from skorch.callbacks import Callback
from skorch.callbacks import EarlyStopping
from skorch.callbacks import InputShapeSetter
from torch import backends
from torch import cuda
from torch import manual_seed
from torch import nn
from torch import optim

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class DeepNeuralNetworkModel(nn.Module):
    """
    Fully connected network made of an input layer, `hidden_layers` hidden layers
    of `units` units each, and an output layer.

    A ReLU follows the input layer and every hidden layer. The dropout layer (or an
    identity layer when dropout is disabled) is applied before each hidden layer.
    All parameters are converted to double precision.
    """

    def __init__(self, hidden_layers, units, input_dim=1, output_dim=1, dropout=0):
        """
        :param hidden_layers: number of (units x units) hidden linear layers
        :param units: width of the input layer output and of every hidden layer
        :param input_dim: number of input features
        :param output_dim: number of outputs of the last layer
        :param dropout: dropout probability; dropout is only enabled when 0 < dropout <= 1
        """
        super(DeepNeuralNetworkModel, self).__init__()

        # Best effort only: these seeds do not guarantee reproducibility.
        # In particular, hp search running on multiple threads usually gives non deterministic results.
        # They still make the trainings as reproducible as possible.
        np.random.seed(0)
        random.seed(0)
        manual_seed(0)
        backends.cudnn.deterministic = True
        backends.cudnn.benchmark = False

        self.input_dim = input_dim
        self.units = units
        self.output_dim = output_dim
        self.dropout = dropout

        # Any dropout value outside of ]0, 1] disables dropout
        if 0 < dropout <= 1:
            self.dropout_or_identity_layer = nn.Dropout(dropout)
        else:
            self.dropout_or_identity_layer = nn.Identity()
        self.non_lin = nn.ReLU()

        # The layers are created in network order (input, hidden, output)
        self.first_dense = nn.Linear(input_dim, units)
        hidden_stack = []
        for _ in range(hidden_layers):
            hidden_stack.append(nn.Linear(units, units))
        self.hidden_layers = nn.ModuleList(hidden_stack)
        # Seems to be needed to be defined in the constructor, otherwise : "Failed to train : optimizer got an empty parameter list"
        self.last_dense = nn.Linear(units, output_dim)
        self.double()

    def forward(self, x):
        """Run `x` through the input layer, the hidden layers and the output layer."""
        out = self.non_lin(self.first_dense(x))
        for dense in self.hidden_layers:
            out = self.non_lin(dense(self.dropout_or_identity_layer(out)))
        return self.last_dense(out)


class OutputShapeSetter(Callback):
    """Callback setting the module `output_dim` to the number of distinct labels found in `y`."""

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        # A fresh label encoder is built on every call (rather than stored on the callback)
        # so that this keeps working after DKUNeuralNetClassifier is cloned
        fitted_encoder = LabelEncoder().fit(y)
        n_classes = len(fitted_encoder.classes_)
        net.set_params(module__output_dim=n_classes)


class EpochCounter(Callback):
    """
    Callback keeping track of the number of entries in the net history.

    After each epoch, `epochs` is set to `len(net.history)`.
    """

    def __init__(self, epochs=0):
        """
        :param epochs: initial value of the counter (0 by default)
        """
        # Store the constructor argument instead of a hard-coded 0, so that the value passed in
        # is the value read back from the instance (the argument used to be silently ignored)
        self.epochs = epochs

    def on_epoch_end(self, net, dataset_train=None, dataset_valid=None, **kwargs):
        """Refresh the counter from the length of the net history."""
        self.epochs = len(net.history)


class _ParamsRemapper:
    remapped_module_params = [
        "hidden_layers",
        "units",
    ]

    def _remove_prefix(self, key):
        if type(key) == str:
            # `module__` prefix
            if key.startswith("module__") and key.replace("module__", "", 1) in self.remapped_module_params:
                return key.replace("module__", "", 1)
        return key

    def _add_prefix(self, key):
        # `module__` prefix
        if key in self.remapped_module_params:
            return "module__"+key
        return key

    def add_prefix_to_dict(self, obj):
        return {self._add_prefix(key): val for key, val in obj.items()}

    def remove_prefix_to_dict(self, obj):
        return {self._remove_prefix(key): val for key, val in obj.items()}


class DKUNeuralNetRegressor(NeuralNetRegressor, _ParamsRemapper):
    """
    Skorch takes an un-instantiated pytorch module, and a lot of other arguments.
    - The arguments that are not prefixed are used directly by skorch (max_epochs, optimizer...)
    - The arguments prefixed with `module__` will be passed to the __init__ method of the module (module__hidden_layers, module__units...)
    - The arguments for the callbacks are prefixed with `callbacks__<callback_name>__` (callbacks__EarlyStopping__...)

    On top of that, the names listed in `remapped_module_params` can be given and are
    returned without their `module__` prefix (see `get_params` / `set_params`).
    """
    def __init__(self, module=DeepNeuralNetworkModel,
                 optimizer=optim.Adam, criterion=nn.MSELoss, max_epochs=50,
                 dropout=0,
                 batch_size=32, device="cpu",
                 early_stopping_enabled=True, early_stopping_patience=5, early_stopping_threshold=1e-4,
                 reg_l2=0, reg_l1=0,
                 callbacks=None, **kwargs):
        """
        :param module: un-instantiated pytorch module
        :param optimizer: passed to skorch as is
        :param criterion: passed to skorch as is
        :param max_epochs: passed to skorch as is
        :param dropout: passed as `module__dropout`, unless `module__dropout` is in kwargs
        :param batch_size: passed to skorch as is
        :param device: passed to skorch; "cuda" is replaced by "cpu" when no CUDA device is available
        :param early_stopping_enabled: whether an EarlyStopping callback is added to the default callbacks
        :param early_stopping_patience: `patience` of the default EarlyStopping callback
        :param early_stopping_threshold: `threshold` of the default EarlyStopping callback
        :param reg_l2: passed as `optimizer__weight_decay`, unless `optimizer__weight_decay` is in kwargs
        :param reg_l1: factor of the L1 penalty added to the loss (see `get_loss`)
        :param callbacks: callbacks to use; the early_stopping_* arguments are ignored when this is given
        :param kwargs: other skorch arguments; remapped module params get their `module__` prefix
        """

        if callbacks is None:
            callbacks = [InputShapeSetter(), EpochCounter()]
            if early_stopping_enabled:
                callbacks.append(EarlyStopping(patience=early_stopping_patience, threshold=early_stopping_threshold))

        if device == "cuda":
            if cuda.is_available():
                logger.info("Will use CUDA devices")
            else:
                device = "cpu"
                logger.warning("No available CUDA device found, falling back to CPU")

        self.reg_l1 = reg_l1

        super(DKUNeuralNetRegressor, self).__init__(
            module,
            optimizer=optimizer,
            criterion=criterion,
            module__dropout=kwargs.pop("module__dropout", dropout),
            max_epochs=max_epochs,
            batch_size=batch_size,
            optimizer__weight_decay=kwargs.pop("optimizer__weight_decay", reg_l2),
            predict_nonlinearity=kwargs.pop("predict_nonlinearity", 'auto'),  # pop needed because of the cloning of estimators leading to two declarations of the value in the init
            warm_start=kwargs.pop("warm_start", False),
            device=device,
            iterator_train__shuffle=kwargs.pop("iterator_train__shuffle", True),  # Shuffle training data on each epoch
            callbacks=callbacks,
            **self.add_prefix_to_dict(kwargs)
        )

    def fit(self, X, y, **fit_params):
        """
        Train the regressor.

        :param X: training data, passed to skorch as is
        :param y: target; any array-like (pandas Series, numpy array, list...) is accepted
                  and reshaped to (n_samples, 1). None is passed to skorch unchanged.
        :param fit_params: passed to skorch as is
        :return: the result of the skorch `fit`
        """
        # This is needed by skorch :
        # The target data shouldn't be 1-dimensional but instead have 2 dimensions,
        # with the second dimension having the same size as the number of regression targets
        if y is not None:
            # np.asarray rather than `y.values`, so that targets that are not pandas objects work too
            y = np.asarray(y).reshape(-1, 1)
        return super(DKUNeuralNetRegressor, self).fit(X, y, **fit_params)

    def predict(self, X):
        """Return the predictions for `X`, flattened to shape (n_samples,)."""
        # skorch returns predictions of shape (n_samples, 1)
        # we want to return predictions of shape (n_samples,)
        prediction = super(DKUNeuralNetRegressor, self).predict(X)
        return prediction.reshape(-1)

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the skorch loss, plus `reg_l1` times the L1 norm of the weights when `reg_l1` > 0."""
        loss = super(DKUNeuralNetRegressor, self).get_loss(y_pred, y_true, X=X, training=training)
        if self.reg_l1 > 0:
            # L1 regularization only on weights (parameters whose name contains "weight")
            # the parameters are tensors, so we use abs() and sum() which are torch tensor methods
            l1_penalty = sum(
                param_value.abs().sum()
                for param_name, param_value in self.module_.named_parameters()
                if "weight" in param_name
            )
            loss += self.reg_l1 * l1_penalty
        return loss

    def __repr__(self):
        return "{}\nparams: {}".format(self.__class__.__name__, pformat(self.get_params(deep=True)))

    def get_params(self, deep=True, **kwargs):
        """Return the skorch params, with the `module__` prefix removed from the remapped module params."""
        return self.remove_prefix_to_dict(super(DKUNeuralNetRegressor, self).get_params(deep=deep, **kwargs))

    def set_params(self, **kwargs):
        """Set the params on skorch, after adding the `module__` prefix to the remapped module params."""
        return super(DKUNeuralNetRegressor, self).set_params(**self.add_prefix_to_dict(kwargs))


class DKUNeuralNetClassifier(NeuralNetClassifier, _ParamsRemapper):
    """
    Skorch takes an un-instantiated pytorch module, and a lot of other arguments.
    - The arguments that are not prefixed are used directly by skorch (max_epochs, optimizer...)
    - The arguments prefixed with `module__` will be passed to the __init__ method of the module (module__hidden_layers, module__units...)
    - The arguments for the callbacks are prefixed with `callbacks__<callback_name>__` (callbacks__EarlyStopping__...)

    On top of that, the names listed in `remapped_module_params` can be given and are
    returned without their `module__` prefix (see `get_params` / `set_params`).
    The labels are encoded with a LabelEncoder before training and decoded in `predict`.
    """
    def __init__(self, module=DeepNeuralNetworkModel,
                 optimizer=optim.SGD, criterion=nn.CrossEntropyLoss, max_epochs=50,
                 dropout=0,
                 batch_size=32, device="cpu",
                 early_stopping_enabled=True, early_stopping_patience=5, early_stopping_threshold=1e-4,
                 reg_l2=0, reg_l1=0,
                 callbacks=None, **kwargs):
        """
        :param module: un-instantiated pytorch module
        :param optimizer: passed to skorch as is
        :param criterion: passed to skorch as is
        :param max_epochs: passed to skorch as is
        :param dropout: passed as `module__dropout`, unless `module__dropout` is in kwargs
        :param batch_size: passed to skorch as is
        :param device: passed to skorch; "cuda" is replaced by "cpu" when no CUDA device is available
        :param early_stopping_enabled: whether an EarlyStopping callback is added to the default callbacks
        :param early_stopping_patience: `patience` of the default EarlyStopping callback
        :param early_stopping_threshold: `threshold` of the default EarlyStopping callback
        :param reg_l2: passed as `optimizer__weight_decay`, unless `optimizer__weight_decay` is in kwargs
        :param reg_l1: factor of the L1 penalty added to the loss (see `get_loss`)
        :param callbacks: callbacks to use; the early_stopping_* arguments are ignored when this is given
        :param kwargs: other skorch arguments; remapped module params get their `module__` prefix
        """
        # Used later for class encoding
        # self._classes will contain the classes used by the label encoder to do the inverse_transform()
        # this is important so that we can do predict on a cloned estimator
        self._le = LabelEncoder()
        self._classes = None

        if callbacks is None:
            callbacks = [InputShapeSetter(), OutputShapeSetter(), EpochCounter()]
            if early_stopping_enabled:
                callbacks.append(EarlyStopping(patience=early_stopping_patience, threshold=early_stopping_threshold))

        # Same fallback as in DKUNeuralNetRegressor: do not keep a CUDA device that is not available
        if device == "cuda":
            if cuda.is_available():
                logger.info("Will use CUDA devices")
            else:
                device = "cpu"
                logger.warning("No available CUDA device found, falling back to CPU")

        self.reg_l1 = reg_l1

        super(DKUNeuralNetClassifier, self).__init__(
            module,
            optimizer=optimizer,
            criterion=criterion,
            module__dropout=kwargs.pop("module__dropout", dropout),
            max_epochs=max_epochs,
            batch_size=batch_size,
            optimizer__weight_decay=kwargs.pop("optimizer__weight_decay", reg_l2),
            predict_nonlinearity=kwargs.pop("predict_nonlinearity", 'auto'),  # pop needed because of the cloning of estimators leading to two declarations of the value in the init
            warm_start=kwargs.pop("warm_start", False),
            device=device,
            iterator_train__shuffle=kwargs.pop("iterator_train__shuffle", True),  # Shuffle training data on each epoch
            callbacks=callbacks,
            **self.add_prefix_to_dict(kwargs)
        )

    def fit(self, X, y, **fit_params):
        """
        Train the classifier on the label-encoded version of `y`.

        The classes found by the label encoder are kept in `self._classes`.

        :return: the result of the skorch `fit`
        """
        self._le.fit(y)
        _y = self._le.transform(y)
        self._classes = self._le.classes_
        return super(DKUNeuralNetClassifier, self).fit(X, _y, **fit_params)

    def predict(self, X):
        """Return the predictions for `X`, decoded back to the original labels."""
        result = super(DKUNeuralNetClassifier, self).predict(X)
        # Make sure that the label encoder has the right classes
        self._le.classes_ = self._classes
        return self._le.inverse_transform(result)

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the skorch loss, plus `reg_l1` times the L1 norm of the weights when `reg_l1` > 0."""
        loss = super(DKUNeuralNetClassifier, self).get_loss(y_pred, y_true, X=X, training=training)
        if self.reg_l1 > 0:
            # L1 regularization only on weights (parameters whose name contains "weight")
            # the parameters are tensors, so we use abs() and sum() which are torch tensor methods
            l1_penalty = sum(
                param_value.abs().sum()
                for param_name, param_value in self.module_.named_parameters()
                if "weight" in param_name
            )
            loss += self.reg_l1 * l1_penalty
        return loss

    def __repr__(self):
        return "{}\nparams: {}".format(self.__class__.__name__, pformat(self.get_params(deep=True)))

    def get_params(self, deep=True, **kwargs):
        """Return the skorch params, with the `module__` prefix removed from the remapped module params."""
        return self.remove_prefix_to_dict(super(DKUNeuralNetClassifier, self).get_params(deep=deep, **kwargs))

    def set_params(self, **kwargs):
        """Set the params on skorch, after adding the `module__` prefix to the remapped module params."""
        return super(DKUNeuralNetClassifier, self).set_params(**self.add_prefix_to_dict(kwargs))
