import gluonts

from dataiku.base.utils import package_is_at_least


def instantiate_simple_feed_forward_estimator(
    trainer,
    frequency,
    prediction_length,
    batch_size,
    context_length,
    num_hidden_dimensions,
    num_parallel_samples,
    batch_normalization,
    mean_scaling,
    distr_output,
):
    """Create a ``SimpleFeedForwardEstimator`` for the installed gluonts version.

    The estimator class is imported from ``gluonts.mx.model.simple_feedforward``
    when gluonts is at least 0.12.0, and from
    ``gluonts.model.simple_feedforward`` otherwise.

    Every argument except ``frequency`` is forwarded to the estimator
    constructor as a keyword argument of the same name. ``frequency`` is
    forwarded as ``freq``, and only when gluonts is older than 0.10.0.

    Returns:
        The newly constructed ``SimpleFeedForwardEstimator`` instance.
    """
    # Resolve the import location first; it depends on the gluonts version.
    if not package_is_at_least(gluonts, "0.12.0"):
        from gluonts.model.simple_feedforward import (
            SimpleFeedForwardEstimator as estimator_cls,
        )
    else:
        from gluonts.mx.model.simple_feedforward import (
            SimpleFeedForwardEstimator as estimator_cls,
        )

    estimator_kwargs = {"trainer": trainer}
    if not package_is_at_least(gluonts, "0.10.0"):
        # Starting from 0.10.0 the freq argument has been removed, so it is
        # only supplied to older versions.
        estimator_kwargs["freq"] = frequency

    # Arguments common to every supported gluonts version.
    estimator_kwargs.update(
        prediction_length=prediction_length,
        batch_size=batch_size,
        context_length=context_length,
        num_hidden_dimensions=num_hidden_dimensions,
        num_parallel_samples=num_parallel_samples,
        batch_normalization=batch_normalization,
        mean_scaling=mean_scaling,
        distr_output=distr_output,
    )
    return estimator_cls(**estimator_kwargs)



