/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.ParameterChecks;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubMetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonSerializationContext;
import java.lang.reflect.Type;

/**
 * Abstract base for the pre-train modeling parameters of DeepHub prediction tasks.
 *
 * <p>Two concrete subclasses are declared in this file: {@code ObjectDetectionPreTrainModelingParams}
 * and {@code ImageClassificationPreTrainModelingParams}. Instances are obtained through
 * {@link #build(PredictionMLTask.PredictionType)} or through the JSON adapter registered in the
 * static initializer, which picks the subclass from the {@code "type"} JSON member.
 *
 * <p>All fields are public and mutable with the defaults shown below; only a subset is checked by
 * {@code validateParameters}.
 */
public abstract class DeepHubPreTrainModelingParams
implements PreTrainModelingParams {
    // NOTE(review): nothing in this file reads this flag - confirm its purpose with the consumers.
    public boolean dummy = false;
    /** Prediction type of the concrete subclass; assigned by each subclass constructor. */
    public PredictionMLTask.PredictionType type;
    /** Metric settings. No default is assigned here, and getEvaluationMetricName() dereferences it. */
    public DeepHubMetricParams metrics;
    public DeepHubModelOptimizationSplitParams modelOptimizationSplitParams = new DeepHubModelOptimizationSplitParams();
    /** Validated to be strictly positive. */
    public double learningRate = 1.0E-4;
    /** Validated to be non-negative. */
    public int nbFinetunedLayers = 0;
    /** Validated to be strictly positive. The image-classification subclass overrides the default to 32. */
    public int perDeviceBatchSize = 2;
    /** Validated to be strictly positive. */
    public int epochs = 50;
    /** Validated to be non-negative. */
    public double weightDecay = 0.0;
    public Optimizer optimizer = Optimizer.ADAM;
    public LrScheduler lrScheduler = LrScheduler.PLATEAU;
    /** Its minDelta and patience are validated to be non-negative. */
    public EarlyStopping earlyStopping = new EarlyStopping();
    // The remaining fields are not validated in this file.
    public int processCountPerNode = 1;
    public ImageAugmentationParams augmentationParams = new ImageAugmentationParams();
    /** No initializer, so it defaults to {@code false}. */
    public boolean enableParallelDataLoading;
    public int numWorkers = 2;

    /**
     * Creates a parameter object, populated with its defaults, for the given DeepHub prediction type.
     *
     * @param predictionType prediction type to create parameters for; a {@code null} value fails
     *                       with a {@link NullPointerException} at the switch
     * @return a new {@link ObjectDetectionPreTrainModelingParams} or
     *         {@link ImageClassificationPreTrainModelingParams}
     * @throws IllegalArgumentException if the prediction type is not one of the two handled here
     */
    public static DeepHubPreTrainModelingParams build(PredictionMLTask.PredictionType predictionType) {
        final DeepHubPreTrainModelingParams created;
        switch (predictionType) {
            case DEEP_HUB_IMAGE_OBJECT_DETECTION:
                created = new ObjectDetectionPreTrainModelingParams();
                break;
            case DEEP_HUB_IMAGE_CLASSIFICATION:
                created = new ImageClassificationPreTrainModelingParams();
                break;
            default:
                throw new IllegalArgumentException("Unsupported prediction type: " + predictionType);
        }
        return created;
    }

    /**
     * Checks the numeric training parameters held by this object.
     *
     * <p>Each condition is passed to {@code ErrorContext.check} together with the message to
     * report when it does not hold.
     *
     * @param task   the task being validated; not used by this base implementation
     * @param checks collector for parameter checks; not used by this base implementation
     */
    public void validateParameters(PredictionMLTask.DeepHubPredictionMLTask task, ParameterChecks checks) {
        ErrorContext.check(this.learningRate > 0.0, "Learning rate must be > 0");
        ErrorContext.check(this.perDeviceBatchSize > 0, "perDeviceBatchSize must be > 0");
        ErrorContext.check(this.epochs > 0, "Number of epochs must be > 0");
        ErrorContext.check(this.weightDecay >= 0.0, "Weight decay must be >= 0");
        // earlyStopping is a public mutable field and becomes null when the JSON carries an
        // explicit null for it; report that as a validation error instead of a bare NPE below.
        ErrorContext.check(this.earlyStopping != null, "Early stopping parameters must be defined");
        ErrorContext.check(this.earlyStopping.minDelta >= 0.0, "Early stopping min_delta must be >= 0");
        ErrorContext.check(this.earlyStopping.patience >= 0, "Early stopping patience must be >= 0");
        ErrorContext.check(this.nbFinetunedLayers >= 0, "Number of fine-tuned layers must be >= 0");
    }

    /**
     * Returns the pre-train description for these parameters.
     *
     * @return always {@code null} in this implementation
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription() {
        return null;
    }

    /**
     * Returns the string form of the evaluation metric configured in {@link #metrics}.
     *
     * @return {@code toString()} of {@code metrics.evaluationMetric}
     * @throws NullPointerException if {@link #metrics} or its evaluation metric is {@code null}
     */
    @Override
    public String getEvaluationMetricName() {
        Object evaluationMetric = this.metrics.evaluationMetric;
        return evaluationMetric.toString();
    }

    static {
        // Registers the JSON adapter for the abstract base type. Deserialization dispatches on the
        // "type" member to one of the two concrete subclasses; serialization is handed back to
        // the context.
        // NOTE(review): delegating to the context assumes this adapter is bound to the exact base
        // type only (not to the whole hierarchy) - confirm in JSON.registerAdapter.
        JSON.registerAdapter(DeepHubPreTrainModelingParams.class, (Object)new JSON.Adapter<DeepHubPreTrainModelingParams>(){

            public DeepHubPreTrainModelingParams deserialize(JsonElement jsonElement, Type scriptType, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
                JsonObject jsonObj = jsonElement.getAsJsonObject();
                // JsonObject.get returns null for an absent member; fail with an explicit parse
                // error rather than a message-less NullPointerException.
                JsonElement typeElement = jsonObj.get("type");
                if (typeElement == null || typeElement.isJsonNull()) {
                    throw new JsonParseException("Missing 'type' member in DeepHub pre-train modeling params");
                }
                String type = typeElement.getAsString();
                if (PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION.name().equals(type)) {
                    return jsonDeserializationContext.deserialize(jsonElement, ObjectDetectionPreTrainModelingParams.class);
                }
                if (PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION.name().equals(type)) {
                    return jsonDeserializationContext.deserialize(jsonElement, ImageClassificationPreTrainModelingParams.class);
                }
                throw new IllegalArgumentException("Unsupported type:" + type);
            }

            public JsonElement serialize(DeepHubPreTrainModelingParams params, Type type, JsonSerializationContext ctx) {
                return ctx.serialize(params);
            }
        });
    }

    /**
     * Settings holder stored in the {@code modelOptimizationSplitParams} field of the enclosing class.
     * Neither value is validated in this file.
     */
    public static class DeepHubModelOptimizationSplitParams {
        // Defaults to 0.8. NOTE(review): presumably the fraction of data kept for training - confirm with the consumer.
        public float trainSplitRatio = 0.8f;
        // Defaults to 1337. NOTE(review): presumably the random seed of the split - confirm with the consumer.
        public int seed = 1337;
    }

    /** Values accepted by the {@code optimizer} field of the enclosing class; the default is {@code ADAM}. */
    public enum Optimizer {
        ADAM, SGD, RMSPROP, ADAMAX, ADAGRAD, ADADELTA
    }

    /** Values accepted by the {@code lrScheduler} field of the enclosing class; the default is {@code PLATEAU}. */
    public enum LrScheduler {
        PLATEAU, STEP, EXPONENTIAL
    }

    /**
     * Early-stopping settings stored in the {@code earlyStopping} field of the enclosing class.
     * {@code minDelta} and {@code patience} are checked to be non-negative by
     * {@code validateParameters}. The implicit no-argument constructor keeps the defaults below.
     */
    static class EarlyStopping {
        /** Enabled by default. */
        boolean enabled = true;
        /** Defaults to 0. */
        double minDelta = 0.0;
        /** Defaults to 5. */
        int patience = 5;
    }

    /**
     * Groups the image augmentation settings: color jitter, affine transforms and random crop.
     * Each member is created with its own defaults.
     */
    public static class ImageAugmentationParams {
        ColorJitterParams colorJitter = new ColorJitterParams();
        AffineParams affine = new AffineParams();
        RandomCropParams crop = new RandomCropParams();
    }

    /**
     * Parameters for the {@code DEEP_HUB_IMAGE_OBJECT_DETECTION} prediction type. The constructor
     * sets {@code type}; every other inherited field keeps the base-class default.
     */
    public static class ObjectDetectionPreTrainModelingParams extends DeepHubPreTrainModelingParams {

        /** Pretrained models available for object detection. */
        public enum PretrainedModel {
            FASTERRCNN
        }

        public PretrainedModel pretrainedModel = PretrainedModel.FASTERRCNN;

        public ObjectDetectionPreTrainModelingParams() {
            type = PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION;
        }

        /** @return the fixed display name {@code "Object detection"} */
        @Override
        public String generateName() {
            return "Object detection";
        }
    }

    /**
     * Parameters for the {@code DEEP_HUB_IMAGE_CLASSIFICATION} prediction type. The constructor
     * sets {@code type} and raises the default {@code perDeviceBatchSize} to 32.
     */
    public static class ImageClassificationPreTrainModelingParams extends DeepHubPreTrainModelingParams {

        /** Pretrained models available for image classification. */
        public enum PretrainedModel {
            EFFICIENTNET_B0, EFFICIENTNET_B4, EFFICIENTNET_B7
        }

        public PretrainedModel pretrainedModel = PretrainedModel.EFFICIENTNET_B4;

        public ImageClassificationPreTrainModelingParams() {
            type = PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION;
            perDeviceBatchSize = 32;
        }

        /** @return the fixed display name {@code "Image classification"} */
        @Override
        public String generateName() {
            return "Image classification";
        }
    }

    /**
     * Random-crop augmentation settings. Inherits {@code enabled} and {@code probability} from
     * {@link ImageAugmentationParam}, whose no-argument constructor is invoked implicitly.
     */
    static class RandomCropParams extends ImageAugmentationParam {
        /** Defaults to 0.75. */
        float minKeptRatio = 0.75f;
        /** Defaults to {@code true}. */
        boolean preserveAspectRatio = true;
    }

    /**
     * Color-jitter augmentation settings. Inherits {@code enabled} and {@code probability} from
     * {@link ImageAugmentationParam}; each of the three values below defaults to 0.2.
     */
    static class ColorJitterParams extends ImageAugmentationParam {
        float contrast = 0.2f;
        float brightness = 0.2f;
        float hue = 0.2f;
    }

    /**
     * Affine augmentation settings. Both flips start disabled with probability 0.5; rotation uses
     * the {@link RotateParams} defaults.
     */
    static class AffineParams {
        ImageAugmentationParam verticalFlip = new ImageAugmentationParam(false, 0.5f);
        ImageAugmentationParam horizontalFlip = new ImageAugmentationParam(false, 0.5f);
        RotateParams rotate = new RotateParams();
    }

    /**
     * Rotation augmentation settings. Inherits {@code enabled} and {@code probability} from
     * {@link ImageAugmentationParam}.
     */
    static class RotateParams extends ImageAugmentationParam {
        /** Defaults to 30. NOTE(review): the unit is not visible in this file - presumably degrees; confirm. */
        float maxRotation = 30.0f;
    }

    /**
     * Common switch shared by the augmentation settings: an on/off flag and a probability.
     * The no-argument constructor yields an enabled parameter with probability 1.
     */
    static class ImageAugmentationParam {
        boolean enabled;
        float probability;

        /** Creates an enabled parameter with probability 1. */
        ImageAugmentationParam() {
            this(true, 1.0f);
        }

        /**
         * @param enabled     initial value of the on/off flag
         * @param probability initial probability value
         */
        ImageAugmentationParam(boolean enabled, float probability) {
            this.enabled = enabled;
            this.probability = probability;
        }
    }
}

