/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.metrics.probes;

import com.dataiku.dip.CodedRuntimeException;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datasets.DatasetCodes;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.meanings.IBasicMeaningsService;
import com.dataiku.dip.metrics.Metric;
import com.dataiku.dip.metrics.MetricComputer;
import com.dataiku.dip.metrics.MetricMetadata;
import com.dataiku.dip.metrics.MetricTargetType;
import com.dataiku.dip.metrics.probes.Probe;
import com.dataiku.dip.metrics.probes.ProbeConfiguration;
import com.dataiku.dip.metrics.probes.ProbeMetadata;
import com.dataiku.dip.metrics.probes.ProbeType;
import com.dataiku.j2ts.annotations.UIModel;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class ModelPerformanceProbeType
extends ProbeType {
    /** Identifier of this probe type; the serializer below uses it as the first part of every metric id. */
    public static final String TYPE = "model_perf";
    /**
     * Display metadata for each {@link ModelPerformanceMetrics} constant, filled once by the static initializer.
     * Keys are enum constants, so an {@link java.util.EnumMap} is used: it is array-backed and compact.
     */
    private static final Map<ModelPerformanceMetrics, MetricMetadata> metadataPerType = Maps.newEnumMap(ModelPerformanceMetrics.class);

    /**
     * Creates the probe type and sets its type identifier to {@link #TYPE}.
     */
    public ModelPerformanceProbeType() {
        this.type = ModelPerformanceProbeType.TYPE;
    }

    /**
     * This probe type contributes no metric computers.
     *
     * @return a new, empty list
     */
    @Override
    public List<MetricComputer> getComputers(IBasicMeaningsService basicMeaningsService) {
        List<MetricComputer> noComputers = Lists.newArrayList();
        return noComputers;
    }

    /**
     * This probe type never asks for metrics to be computed, whatever the target.
     *
     * @return a new, empty list
     */
    @Override
    public List<Metric> getMetricsToCompute(Probe probe, Object object, MetricTargetType objectType, boolean forDisplay) {
        List<Metric> nothingToCompute = Lists.newArrayList();
        return nothingToCompute;
    }

    /**
     * No full configuration is built for this probe type.
     *
     * @return always {@code null}
     */
    @Override
    public ProbeConfiguration buildFullConfiguration(List<SchemaColumn> column, Probe probe) {
        ProbeConfiguration noConfiguration = null;
        return noConfiguration;
    }

    /**
     * @return a fresh instance of this probe type
     */
    @Override
    public ProbeType trimForSave() {
        ModelPerformanceProbeType trimmed = new ModelPerformanceProbeType();
        return trimmed;
    }

    /**
     * @return the configuration class used for this probe's parameters
     */
    @Override
    public Class<? extends ProbeConfiguration> getParamsClazz() {
        Class<? extends ProbeConfiguration> paramsClass = ModelPerformanceProbeConfiguration.class;
        return paramsClass;
    }

    /**
     * @return probe metadata with level 0 and the display name "Model performance"
     */
    @Override
    public ProbeMetadata getMeta() {
        ProbeMetadata meta = new ProbeMetadata().withLevel(0);
        return meta.withName("Model performance");
    }

    /**
     * @return always {@code false}: this probe is not user-selected
     */
    @Override
    public boolean isUserSelectedProbe() {
        boolean userSelected = false;
        return userSelected;
    }

    static {
        for (ModelPerformanceMetrics metric : ModelPerformanceMetrics.values()) {
            metadataPerType.put(metric, new MetricMetadata().withName(metric.displayName).withFormat("longReadableNumber"));
        }
    }

    /**
     * Parameter holder for the model performance probe. It declares no fields of its own.
     */
    public static class ModelPerformanceProbeConfiguration
    implements ProbeConfiguration {
        // Intentionally empty: the probe currently takes no parameters.
    }

    /**
     * Metric types that the model performance probe can report.
     *
     * <p>Each constant carries a human-readable {@code displayName}. Constants flagged {@code isPerAssertion}
     * are identified by an additional assertion name (see the two-argument {@code ModelPerformanceMetric}
     * constructor).
     */
    public static enum ModelPerformanceMetrics {
        R2("R2"),
        EVS("EVS"),
        MAPE("MAPE"),
        MAE("MAE"),
        MSE("MSE"),
        RMSE("RMSE"),
        RMSLE("RMSLE"),
        PEARSON("Pearson"),
        CUSTOMOPTIMISATIONSCORE("Custom optimisation score"),
        F1("F1"),
        RECALL("Recall"),
        PRECISION("Precision"),
        ACCURACY("Accuracy"),
        LOGLOSS("Log. loss"),
        AUC("AUC"),
        LIFT("Lift"),
        CALIBRATIONLOSS("Calibration Loss"),
        AVERAGE_PRECISION("Average Precision"),
        CUSTOM("Custom Metric"),
        PASSING_ASSERTIONS_RATIO("Passing assertions ratio"),
        ASSERTION_NB_MATCHING_ROWS("Rows matching criteria", true),
        ASSERTION_NB_DROPPED_ROWS("Dropped rows", true),
        ASSERTION_VALID_RATIO("Valid ratio", true),
        AVERAGE_PRECISION_IOU50("Average Precision (IoU=0.50)"),
        AVERAGE_PRECISION_IOU75("Average Precision (IoU=0.75)"),
        AVERAGE_PRECISION_ALL_IOU("Average Precision (all IoUs)"),
        MASE("MASE"),
        MEAN_ABSOLUTE_QUANTILE_LOSS("Mean Absolute Quantile Loss"),
        MEAN_WEIGHTED_QUANTILE_LOSS("Mean Weighted Quantile Loss"),
        MSIS("MSIS"),
        ND("ND"),
        SMAPE("sMAPE"),
        WORST_MASE("Worst MASE"),
        WORST_MAPE("Worst MAPE"),
        WORST_SMAPE("Worst sMAPE"),
        WORST_MSE("Worst MSE"),
        WORST_MSIS("Worst MSIS"),
        WORST_MAE("Worst MAE"),
        AUUC("AUUC"),
        QINI("QINI"),
        NET_UPLIFT("Net uplift"),
        DATA_DRIFT("Data Drift"),
        DATA_DRIFT_PVALUE("Data Drift p-value"),
        MIN_KS("Minimum KS"),
        MIN_CHISQUARE("Minimum Chi-square"),
        MAX_PSI("Maximum PSI"),
        UNIVARIATE_DRIFT("Univariate drift"),
        PREDICTION_DRIFT_KS("Prediction Drift KS"),
        PREDICTION_DRIFT_CHISQUARE("Prediction Drift Chi-square"),
        PREDICTION_DRIFT_PSI("Prediction Drift PSI"),
        EMBEDDING_DRIFT("Embedding drift"),
        IMAGE_QUALITY_DRIFT("Image Quality drift");

        /** True when a metric of this type must be qualified by an assertion name. */
        final boolean isPerAssertion;
        /** Human-readable label, used as the metric display name. */
        final String displayName;

        private ModelPerformanceMetrics(String displayName) {
            this(displayName, false);
        }

        private ModelPerformanceMetrics(String displayName, boolean isPerAssertion) {
            this.displayName = displayName;
            this.isPerAssertion = isPerAssertion;
        }

        /**
         * Parses a metric type name. The alias {@code "CUSTOMSCORE"} maps to {@link #CUSTOMOPTIMISATIONSCORE};
         * any other value must be an exact constant name.
         *
         * @param metricType the name to parse
         * @return the matching constant
         * @throws IllegalArgumentException if no constant has that name
         */
        public static ModelPerformanceMetrics fromString(String metricType) {
            boolean isCustomScoreAlias = "CUSTOMSCORE".equals(metricType);
            return isCustomScoreAlias ? CUSTOMOPTIMISATIONSCORE : ModelPerformanceMetrics.valueOf(metricType);
        }
    }

    public static class ModelPerformanceMetricSerializer
    implements Metric.MetricIdSerializer {
        @Override
        public String serializeMetric(Metric metric) {
            if (!(metric instanceof ModelPerformanceMetric)) {
                throw new CodedRuntimeException((InfoMessage.MessageCode)DatasetCodes.ERR_DATASET_INVALID_METRIC_IDENTIFIER, "Probe type " + this.getClass().getSimpleName() + " does not handle " + metric.getClass().getSimpleName());
            }
            ModelPerformanceMetric modelPerformanceMetric = (ModelPerformanceMetric)metric;
            if (modelPerformanceMetric.metricType.isPerAssertion) {
                return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name(), modelPerformanceMetric.assertionName);
            }
            if (modelPerformanceMetric.isCustomEvaluationMetric) {
                return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name(), modelPerformanceMetric.customMetricName);
            }
            if (modelPerformanceMetric.isUnivariateDriftMetric) {
                return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name(), modelPerformanceMetric.featureName + "_" + modelPerformanceMetric.univariateDriftMetric.testName);
            }
            if (modelPerformanceMetric.isEmbeddingDriftMetric) {
                return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name(), modelPerformanceMetric.featureName + "_" + modelPerformanceMetric.embeddingDriftMetric.metricName);
            }
            if (modelPerformanceMetric.isImageQualityDriftMetric) {
                return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name(), modelPerformanceMetric.featureName + "_" + modelPerformanceMetric.imageQualityDriftMetric.metricName);
            }
            return Metric.buildMetricIdFromParts(ModelPerformanceProbeType.TYPE, modelPerformanceMetric.getMetricType().name());
        }

        @Override
        public Metric deserializeMetric(String metricId) {
            List<String> parts = Metric.buildPartsFromMetricId(metricId);
            if (parts.get(0).equals(ModelPerformanceProbeType.TYPE)) {
                ModelPerformanceMetrics metricType = ModelPerformanceMetrics.fromString(parts.get(1));
                if (parts.size() == 2) {
                    return new ModelPerformanceMetric(metricType);
                }
                if (parts.size() == 3) {
                    String metricName = parts.get(2);
                    if (metricType == ModelPerformanceMetrics.CUSTOM) {
                        return new ModelPerformanceMetric(metricName);
                    }
                    if (metricType == ModelPerformanceMetrics.UNIVARIATE_DRIFT) {
                        UnivariateDriftMetric driftMetric = UnivariateDriftMetric.getUnivariateDriftMetric(metricName);
                        return new ModelPerformanceMetric(driftMetric, metricName.substring(0, metricName.length() - driftMetric.testName.length() - 1));
                    }
                    if (metricType == ModelPerformanceMetrics.EMBEDDING_DRIFT) {
                        EmbeddingDriftMetric embeddingDriftMetric = EmbeddingDriftMetric.getEmbeddingDriftMetric(metricName);
                        return new ModelPerformanceMetric(embeddingDriftMetric, metricName.substring(0, metricName.length() - embeddingDriftMetric.metricName.length() - 1));
                    }
                    if (metricType == ModelPerformanceMetrics.IMAGE_QUALITY_DRIFT) {
                        ImageQualityDriftMetric imageQualityDriftMetric = ImageQualityDriftMetric.getImageQualityDriftMetric(metricName);
                        return new ModelPerformanceMetric(imageQualityDriftMetric, metricName.substring(0, metricName.length() - imageQualityDriftMetric.metricName.length() - 1));
                    }
                    return new ModelPerformanceMetric(metricType, metricName);
                }
            }
            throw new CodedRuntimeException((InfoMessage.MessageCode)DatasetCodes.ERR_DATASET_INVALID_METRIC_IDENTIFIER, "Probe type " + this.getClass().getSimpleName() + " does not handle " + metricId);
        }
    }

    /**
     * A single model performance metric. Besides its {@link ModelPerformanceMetrics} type, it may carry:
     * <ul>
     *   <li>an assertion name, for per-assertion types;</li>
     *   <li>a custom metric name, for {@code CUSTOM};</li>
     *   <li>a feature name plus a drift sub-metric, for univariate, embedding or image quality drift.</li>
     * </ul>
     * Each constructor computes the metric id through {@link Metric#serializeMetric}.
     */
    @UIModel
    public static class ModelPerformanceMetric
    extends Metric {
        public ModelPerformanceMetrics metricType;
        /** Set only by the (metricType, assertionName) constructor. */
        public String assertionName;
        /** Set only by the drift-metric constructors. */
        public String featureName;
        public boolean isCustomEvaluationMetric = false;
        public String customMetricName;
        public boolean isUnivariateDriftMetric = false;
        public UnivariateDriftMetric univariateDriftMetric;
        public boolean isEmbeddingDriftMetric = false;
        public EmbeddingDriftMetric embeddingDriftMetric;
        public boolean isImageQualityDriftMetric = false;
        public ImageQualityDriftMetric imageQualityDriftMetric;

        /** Creates a plain metric; per-assertion types are rejected when assertions are enabled. */
        public ModelPerformanceMetric(ModelPerformanceMetrics metricType) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(metricType));
            assert (!metricType.isPerAssertion) : "Invalid metricType: " + metricType.toString() + ", per assertion metric is not accepted without assertionName";
            this.metricType = metricType;
            this.id = Metric.serializeMetric(this);
        }

        /** Creates a {@code CUSTOM} metric identified by {@code customMetricName}. */
        public ModelPerformanceMetric(String customMetricName) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(ModelPerformanceMetrics.CUSTOM));
            this.metricType = ModelPerformanceMetrics.CUSTOM;
            this.customMetricName = customMetricName;
            this.isCustomEvaluationMetric = true;
            this.id = Metric.serializeMetric(this);
        }

        /** Creates a per-assertion metric; non-per-assertion types are rejected when assertions are enabled. */
        public ModelPerformanceMetric(ModelPerformanceMetrics metricType, String assertionName) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(metricType));
            assert (metricType.isPerAssertion) : "Invalid metricType: " + metricType.toString() + ", expected a per assertion metric";
            this.metricType = metricType;
            this.assertionName = assertionName;
            this.id = Metric.serializeMetric(this);
        }

        /** Creates a {@code UNIVARIATE_DRIFT} metric for one feature. */
        public ModelPerformanceMetric(UnivariateDriftMetric univariateDriftMetric, String featureName) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(ModelPerformanceMetrics.UNIVARIATE_DRIFT));
            this.metricType = ModelPerformanceMetrics.UNIVARIATE_DRIFT;
            this.isUnivariateDriftMetric = true;
            this.univariateDriftMetric = univariateDriftMetric;
            this.featureName = featureName;
            this.id = Metric.serializeMetric(this);
        }

        /** Creates an {@code EMBEDDING_DRIFT} metric for one feature. */
        public ModelPerformanceMetric(EmbeddingDriftMetric embeddingDriftMetric, String featureName) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(ModelPerformanceMetrics.EMBEDDING_DRIFT));
            this.metricType = ModelPerformanceMetrics.EMBEDDING_DRIFT;
            this.isEmbeddingDriftMetric = true;
            this.embeddingDriftMetric = embeddingDriftMetric;
            this.featureName = featureName;
            this.id = Metric.serializeMetric(this);
        }

        /** Creates an {@code IMAGE_QUALITY_DRIFT} metric for one feature. */
        public ModelPerformanceMetric(ImageQualityDriftMetric imageQualityDriftMetric, String featureName) {
            super(ModelPerformanceProbeType.TYPE, ModelPerformanceMetric.getDataTypeFromMetricType(ModelPerformanceMetrics.IMAGE_QUALITY_DRIFT));
            this.metricType = ModelPerformanceMetrics.IMAGE_QUALITY_DRIFT;
            this.isImageQualityDriftMetric = true;
            this.imageQualityDriftMetric = imageQualityDriftMetric;
            this.featureName = featureName;
            this.id = Metric.serializeMetric(this);
        }

        /** Every model performance metric is stored as a double, whatever its type. */
        private static Type getDataTypeFromMetricType(ModelPerformanceMetrics metricType) {
            return Type.DOUBLE;
        }

        public ModelPerformanceMetrics getMetricType() {
            return this.metricType;
        }

        /** These metrics are not tied to a column. */
        @Override
        public String getColumn() {
            return null;
        }

        @Override
        public String getColumnInvariantId(String placeholder) {
            return this.getId();
        }

        /**
         * Returns display metadata for this metric. For per-assertion metrics, the display name has
         * {@code ": " + assertionName} appended.
         *
         * <p>A fresh copy is returned on every call. The per-type entry is shared by every metric of that type
         * and {@link MetricMetadata} is mutable, so handing it out directly would let one caller corrupt it
         * for all others.
         */
        @Override
        public MetricMetadata getMeta() {
            MetricMetadata typeMetaData = metadataPerType.get(this.metricType);
            MetricMetadata metadata = new MetricMetadata(typeMetaData);
            if (this.metricType.isPerAssertion) {
                metadata.name = typeMetaData.name + ": " + this.assertionName;
            }
            return metadata;
        }

        /** Builds a model performance probe regardless of {@code probes}, which is not consulted. */
        @Override
        public Probe getMatchingProbe(List<Probe> probes) {
            return new Probe(this.getType()).withConfiguration(new ModelPerformanceProbeConfiguration()).withMeta(ProbeType.getProbeType(ModelPerformanceProbeType.TYPE).getMeta());
        }
    }

    public static enum ImageQualityDriftMetric {
        MEAN_RED("Mean Red KS"),
        MEAN_GREEN("Mean Green KS"),
        MEAN_BLUE("Mean Blue KS"),
        MEAN_SATURATION("Mean saturation KS"),
        RMS_CONTRAST("RMS contrast KS"),
        LAPLACIAN_VAR("Laplacian variance KS"),
        TENENGRAD("Tenengrad KS"),
        ENTROPY("Entropy KS"),
        EDGE_DENSITY("Edge density KS"),
        AREA("Area KS"),
        ASPECT_RATIO("Aspect ratio KS");

        public final String metricName;

        private ImageQualityDriftMetric(String metricName) {
            this.metricName = metricName;
        }

        public static ImageQualityDriftMetric getImageQualityDriftMetric(String imageQualityDriftMetricName) {
            return Arrays.stream(ImageQualityDriftMetric.values()).filter(metric -> imageQualityDriftMetricName.endsWith(metric.metricName)).findFirst().orElseThrow(() -> new IllegalArgumentException("Unhandled metric name: " + imageQualityDriftMetricName));
        }
    }

    /**
     * Per-feature embedding drift sub-metrics. In a serialized metric name, {@code metricName} is the suffix
     * that follows the feature name.
     */
    public static enum EmbeddingDriftMetric {
        EUCLIDIAN_DISTANCE("Euclidian distance"),
        COSINE_SIMILARITY("Cosine similarity"),
        CLASSIFIER_GINI("Classifier Gini");

        final String metricName;

        private EmbeddingDriftMetric(String metricName) {
            this.metricName = metricName;
        }

        /**
         * Finds the constant whose {@code metricName} is a suffix of the given name. Candidates are checked in
         * declaration order.
         *
         * @throws IllegalArgumentException if no constant matches
         */
        public static EmbeddingDriftMetric getEmbeddingDriftMetric(String embeddingDriftMetricName) {
            for (EmbeddingDriftMetric candidate : EmbeddingDriftMetric.values()) {
                if (embeddingDriftMetricName.endsWith(candidate.metricName)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Unhandled metric name: " + embeddingDriftMetricName);
        }
    }

    /**
     * Per-feature univariate drift tests. In a serialized metric name, {@code testName} is the suffix that
     * follows the feature name.
     */
    public static enum UnivariateDriftMetric {
        KS("KS"),
        CHISQUARE("Chi-square"),
        PSI("PSI");

        final String testName;

        private UnivariateDriftMetric(String testName) {
            this.testName = testName;
        }

        /**
         * Finds the test whose {@code testName} is a suffix of the given name. Tests are checked in the order
         * PSI, KS, Chi-square.
         *
         * @throws IllegalArgumentException if no test matches
         */
        public static UnivariateDriftMetric getUnivariateDriftMetric(String univariateDriftMetricName) {
            UnivariateDriftMetric[] lookupOrder = new UnivariateDriftMetric[]{PSI, KS, CHISQUARE};
            for (UnivariateDriftMetric candidate : lookupOrder) {
                if (univariateDriftMetricName.endsWith(candidate.testName)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Unhandled metric name: " + univariateDriftMetricName);
        }
    }
}

