import torch
import pytorch_lightning as pl

from dataiku.doctor.timeseries.models.gluonts.base_estimator import DkuGluonTSDeepLearningEstimator

from gluonts.torch.distributions.studentT import StudentTOutput

from dataiku.doctor.utils.gpu_execution import get_default_gpu_config, GluonTSTorchGPUCapability

# Registry of supported distribution outputs, keyed by the name accepted in the
# estimator's ``distr_output`` parameter. Each value is a single instance built
# once at import time and shared by every lookup.
DISTR_OUTPUT = dict(
    StudentTOutput=StudentTOutput(),
)


class DkuGluonTSTorchDeepLearningEstimator(DkuGluonTSDeepLearningEstimator):
    """Dataiku estimator layer for GluonTS models backed by PyTorch.

    Adds torch-specific settings (distribution output name, weight decay) on top
    of :class:`DkuGluonTSDeepLearningEstimator`. Also resolves the PyTorch
    Lightning ``accelerator`` / ``devices`` pair from the GPU configuration.
    (Presumably subclassed by the concrete torch models; confirm.)
    """

    def __init__(self,
                 frequency,
                 prediction_length,
                 time_variable,
                 target_variable,
                 timeseries_identifiers,
                 weight_decay,
                 use_timeseries_identifiers_as_features=False,
                 full_context=True,
                 batch_size=32,
                 epochs=10,
                 auto_num_batches_per_epoch=True,
                 num_batches_per_epoch=50,
                 learning_rate=.001,
                 seed=1337,
                 context_length=1,
                 distr_output="StudentTOutput",
                 gpu_config=None,
                 monthly_day_alignment=None,
                 ):
        """Store the torch-specific settings and choose the Lightning accelerator.

        All parameters except ``weight_decay`` and ``distr_output`` are forwarded
        unchanged to the parent constructor.

        :param weight_decay: weight decay value, stored as ``self.weight_decay``.
        :param distr_output: name of the distribution output, stored as
            ``self.distr_output``. Resolve it to an object with
            :meth:`get_distr_output_class`.
        :param gpu_config: GPU configuration. ``None`` (the default) means
            ``get_default_gpu_config()``, evaluated afresh for each instance.
        """
        # Resolve the default here rather than in the signature. A call in the
        # signature runs once at import time, so every instance created without
        # an explicit config would share that same object (presumably a dict;
        # confirm). It would also reflect the environment at import time, not
        # at construction time.
        if gpu_config is None:
            gpu_config = get_default_gpu_config()

        super(DkuGluonTSTorchDeepLearningEstimator, self).__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            use_timeseries_identifiers_as_features=use_timeseries_identifiers_as_features,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.distr_output = distr_output
        self.weight_decay = weight_decay

        # Map the GPU config onto the Lightning Trainer's accelerator/devices
        # arguments. Without a usable GPU, fall back to CPU and let Lightning
        # pick the device count ("auto").
        if GluonTSTorchGPUCapability.should_use_gpu(self.gpu_config):
            self.accelerator = "gpu"
            self.devices = GluonTSTorchGPUCapability.get_lightning_devices(self.gpu_config)
        else:
            self.accelerator = "cpu"
            self.devices = "auto"

    def _prepare_predict(self):
        """Run the parent preparation, then reseed torch and Lightning.

        Reseeding with ``self.seed`` before prediction keeps sampled forecasts
        reproducible. ``workers=True`` asks Lightning to derive seeds for
        dataloader worker processes as well.
        """
        super(DkuGluonTSTorchDeepLearningEstimator, self)._prepare_predict()
        torch.manual_seed(self.seed)
        pl.seed_everything(self.seed, workers=True)

    @staticmethod
    def get_distr_output_class(distr_output):
        """Return the distribution-output object registered under ``distr_output``.

        Despite its name, this returns an instance, not a class. A registered
        name yields the shared instance stored in ``DISTR_OUTPUT``. Any other
        name silently falls back to a new ``StudentTOutput()``.
        """
        if distr_output in DISTR_OUTPUT:
            return DISTR_OUTPUT[distr_output]
        # Build the fallback only when needed. ``dict.get(key, default)`` would
        # construct it on every call, even when the key is found.
        return StudentTOutput()
