import logging
import torch

logger = logging.getLogger(__name__)


def init_optimizer(model, modeling_params):
    """Build the torch optimizer described by ``modeling_params``.

    Only the parameters of ``model`` whose ``requires_grad`` flag is set are handed to the optimizer.

    :param model: object exposing ``parameters()``; its trainable parameters are optimized
    :param modeling_params: dict read for the keys "optimizer" (one of "ADAM", "SGD", "RMSPROP",
        "ADAMAX", "ADAGRAD", "ADADELTA"), "learningRate" and "weightDecay"
    :return: the instantiated torch optimizer
    :raises ValueError: if "optimizer" matches none of the supported names
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    choice = modeling_params["optimizer"]
    supported = (
        ("ADAM", torch.optim.Adam),
        ("SGD", torch.optim.SGD),
        ("RMSPROP", torch.optim.RMSprop),
        ("ADAMAX", torch.optim.Adamax),
        ("ADAGRAD", torch.optim.Adagrad),
        ("ADADELTA", torch.optim.Adadelta),
    )
    # Linear scan with ``==`` (rather than a dict lookup) so the name is compared exactly as before.
    optimizer_cls = next((cls for key, cls in supported if choice == key), None)
    if optimizer_cls is None:
        raise ValueError("Unknown optimizer: {}".format(choice))

    optimizer = optimizer_cls(trainable,
                              lr=modeling_params["learningRate"],
                              weight_decay=modeling_params["weightDecay"])

    logger.info("Optimizer initialized with following params {}".format(optimizer.state_dict().get("param_groups")))
    return optimizer


class LearningRateStrategy(object):
    """Drive ``optimizer``'s learning rate with a warmup scheduler plus an epoch-level scheduler.

    * The warmup scheduler is a ``LambdaLR`` advanced by :meth:`on_batch_end`. It is dropped
      at the end of epoch 0.
    * The epoch-level scheduler is selected by ``modeling_params["lrScheduler"]`` ("PLATEAU",
      "STEP" or "EXPONENTIAL") and advanced by :meth:`on_epoch_end`.

    Both schedulers are created by :meth:`on_train_start`.
    """

    def __init__(self, modeling_params, optimizer):
        self.modeling_params = modeling_params
        self.optimizer = optimizer

        # Populated by on_train_start; None means "no scheduler to step".
        self._warmup_lr_scheduler = None
        self._on_epoch_end_lr_scheduler = None

    def on_train_start(self, num_batches_per_epoch):
        """ initialize both learning rate schedulers (warmup first, then the epoch-level one) """
        self._init_warmup_lr_scheduler(num_batches_per_epoch)
        self._init_lr_scheduler()

    def on_batch_end(self):
        """Advance the warmup scheduler by one step while it is still active."""
        warmup = self._warmup_lr_scheduler
        if warmup is None:
            return
        warmup.step()

    def on_epoch_end(self, epoch_index, val_metric_gib):
        """Drop the warmup after epoch 0 and step the epoch-level scheduler.

        :param epoch_index: index of the epoch that just finished (0-based)
        :param val_metric_gib: validation metric. It is only used by the "PLATEAU" strategy, whose
            ``ReduceLROnPlateau`` runs in mode="max", so larger values count as improvement.
        """
        if epoch_index == 0:
            # Warmup only covers the first epoch.
            self._warmup_lr_scheduler = None

        scheduler = self._on_epoch_end_lr_scheduler
        if scheduler is None:
            return

        if self.modeling_params["lrScheduler"] == "PLATEAU":
            scheduler.step(val_metric_gib)
        else:
            scheduler.step()

    def _init_warmup_lr_scheduler(self, total_nb_batches):
        """Attach a linear-warmup ``LambdaLR`` covering at most 1000 batches.

        The factor applied to the **initial LR value** at scheduler step ``k`` (step 0 happens when
        the scheduler is created) is ``(k + 1) / (warmup_batches + 1)`` while ``k < warmup_batches``.
        After that it is 1.
        eg: if 2 batches per epoch:
            epoch0: - train the 1st batch with initial_lr *1/3.
                    - train 2nd batch with initial_lr *2/3
            epoch1: train 1st batch with full initial lr (*3/3)
        """
        n_warmup = min(1000, total_nb_batches)

        def warmup_factor(step_idx):
            if step_idx >= n_warmup:
                return 1
            return float(step_idx + 1) / (n_warmup + 1)

        logger.info("Using a warmup LR scheduler for {} batches (epoch 0 only).".format(n_warmup))
        self._warmup_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_factor)

    def _init_lr_scheduler(self):
        """Create the epoch-level scheduler selected by ``modeling_params["lrScheduler"]``.

        :raises ValueError: if the strategy name is not "PLATEAU", "STEP" or "EXPONENTIAL"
        """
        schedulers = torch.optim.lr_scheduler
        kind = self.modeling_params["lrScheduler"]

        if kind == "PLATEAU":
            # Keep the plateau patience below the early-stopping patience, so training does not stop
            # before the LR has had a chance to be reduced.
            early_stopping = self.modeling_params["earlyStopping"]
            lr_patience = int(early_stopping["patience"] / 2.0) if early_stopping["enabled"] else 10
            logger.info("LR scheduler patience chosen for Plateau strategy: {} epochs".format(lr_patience))
            scheduler = schedulers.ReduceLROnPlateau(self.optimizer, patience=lr_patience, mode="max")
        elif kind == "STEP":
            # Step roughly every third of the planned epochs (at least 1). This is a large, infrequent drop
            # compared with the exponential decay, and early stopping may end training before it happens.
            step_size = max(int(self.modeling_params["epochs"] / 3.), 1)
            scheduler = schedulers.StepLR(self.optimizer, step_size=step_size)
        elif kind == "EXPONENTIAL":
            scheduler = schedulers.ExponentialLR(self.optimizer, gamma=0.95)
        else:
            raise ValueError("Unknown LR scheduler {}".format(kind))

        self._on_epoch_end_lr_scheduler = scheduler
        logger.info("LR scheduler initialized with: {}".format(scheduler.state_dict()))
