import logging
import numpy as np

from dataiku.doctor.timeseries.models.gluonts.base_estimator import DkuGluonTSDeepLearningEstimator
from dataiku.doctor.utils.gpu_execution import get_single_gpu_id_from_gpu_device, GluonTSMXNetGPUCapability


# Handle missing libraries for MxNet (SC-98195)
# Importing mxnet may fail with an OSError when one of the native GPU
# libraries it links against (libcudart / libcudnn / libnccl) cannot be loaded.
# Translate that into an ImportError with an actionable message so callers can
# treat it like any other missing dependency.
try:
    import mxnet as mx
    from gluonts.mx.distribution import StudentTOutput
    from gluonts.mx.distribution import GaussianOutput
    from gluonts.mx.distribution import NegativeBinomialOutput
    from gluonts.mx.trainer import Trainer
except OSError as e:
    # str(e) is always a string. e.args[0] may instead be an int errno
    # (OSError(errno, strerror)) or missing entirely, which would make the
    # substring checks below raise TypeError/IndexError.
    error_message = str(e)
    if "libcudart" in error_message:
        raise ImportError('CUDA Runtime Library not found: required for GPU execution of MxNet. Contact your administrator to install it') from e
    if "libcudnn" in error_message:
        raise ImportError('CuDNN library not found: required for GPU execution of MxNet. Contact your administrator to install it') from e
    if "libnccl" in error_message:
        raise ImportError('NCCL library not found: required for GPU execution of MxNet. Contact your administrator to install it') from e
    # Any other loader failure: surface its message unchanged, keeping the cause.
    raise ImportError(error_message) from e

# Distribution name (as found in the model configuration) -> GluonTS MXNet
# distribution output instance. These instances are created once at import
# time and shared by every lookup.
DISTR_OUTPUT = dict(
    StudentTOutput=StudentTOutput(),
    GaussianOutput=GaussianOutput(),
    NegativeBinomialOutput=NegativeBinomialOutput(),
)

logger = logging.getLogger(__name__)


class DkuGluonTSMXNetDeepLearningEstimator(DkuGluonTSDeepLearningEstimator):
    """GluonTS deep-learning estimator backed by the MXNet implementation.

    Adds the MXNet-specific pieces on top of ``DkuGluonTSDeepLearningEstimator``:
    resolving the distribution output by name, seeding MXNet's random number
    generators, and building a ``gluonts.mx.trainer.Trainer`` on the CPU or
    GPU context selected from ``self.gpu_config``.
    """

    @staticmethod
    def get_distr_output_class(distr_output):
        """Return the GluonTS MXNet distribution output registered under a name.

        :param distr_output: distribution name, expected to be one of the keys of
            ``DISTR_OUTPUT`` ("StudentTOutput", "GaussianOutput",
            "NegativeBinomialOutput").
        :return: the matching (module-level, shared) distribution output instance,
            or a new ``StudentTOutput()`` when the name is not recognised.
        """
        known_output = DISTR_OUTPUT.get(distr_output)
        if known_output is not None:
            return known_output
        # Build the fallback lazily: passing StudentTOutput() as the dict.get
        # default would instantiate it on every lookup, even for known names.
        if distr_output is not None:
            logger.warning(
                "Unknown distribution output %r, falling back to StudentTOutput",
                distr_output,
            )
        return StudentTOutput()

    def _prepare_predict(self):
        """Run the base prediction setup, then seed MXNet on the predictor's context.

        Seeding with ``self.seed`` on ``self.predictor.ctx`` makes sampling done
        during prediction start from a known random state.
        """
        super(DkuGluonTSMXNetDeepLearningEstimator, self)._prepare_predict()
        mx.random.seed(self.seed, self.predictor.ctx)

    def get_trainer(self, train_data):
        """Build the GluonTS MXNet ``Trainer`` for this estimator.

        When ``self.auto_num_batches_per_epoch`` is set, the method first
        recomputes ``self.num_batches_per_epoch`` from ``train_data``, mutating
        the estimator. The MXNet context is CPU or a single GPU, as returned by
        ``GluonTSMXNetGPUCapability.get_device(self.gpu_config)``. MXNet's RNG is
        seeded on that context with ``self.seed`` before the trainer is created.

        :param train_data: training data, passed to
            ``_compute_auto_num_batches_per_epoch`` when auto-sizing is enabled.
        :return: a configured ``gluonts.mx.trainer.Trainer``.
        """
        if self.auto_num_batches_per_epoch:
            self.num_batches_per_epoch = self._compute_auto_num_batches_per_epoch(train_data)

        device = GluonTSMXNetGPUCapability.get_device(self.gpu_config)
        is_cpu = device == 'cpu'
        if is_cpu:
            mxnet_context = mx.context.cpu()
        else:
            mxnet_context = mx.context.gpu(get_single_gpu_id_from_gpu_device(device))

        mx.random.seed(self.seed, mxnet_context)

        return Trainer(
            ctx=mxnet_context,
            learning_rate=self.learning_rate,
            epochs=self.epochs,
            num_batches_per_epoch=self.num_batches_per_epoch,
            # Hybridization is only enabled on CPU: known MXNet issue on GPU,
            # see https://github.com/awslabs/gluonts/issues/2264
            hybridize=is_cpu,
        )

