/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.coreservices;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLDiagnostics;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.clustering.PreTrainClusteringModelingParams;
import com.dataiku.dip.analysis.model.clustering.ResolvedClusteringCoreParams;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.BinaryClassificationModelPerf;
import com.dataiku.dip.analysis.model.prediction.MulticlassModelPerf;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.RegressionModelPerf;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.reports.IReflectedEventsService;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.MLTaskStateChangedEvent;
import com.dataiku.dip.server.notifications.backend.ReflectedEventEvent;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.warnings.WarningsContext;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class TrainingSessionDetailsService {
    @Autowired
    private PubSubService pubSub;
    private static final Logger logger = Logger.getLogger((String)"dku.analysis.trainingdetails");

    /**
     * Summarizes a prediction training session and publishes the resulting events.
     *
     * @param taskLoc   location of the ML task whose session models are listed
     * @param sessionId identifier of the training session to summarize
     * @param initiator value forwarded as-is to the state-changed event
     * @param cp        core params providing project key, id and input dataset name for the event
     * @param task      ML task forwarded as-is to the state-changed event
     */
    public void publishPredictionSession(MLTaskLoc taskLoc, String sessionId, String initiator, AnalysisCoreParams cp, MLTask task) {
        TrainingSessionDetails predictionDetails = new PredictionTrainingSessionDetails();
        publishEvent(predictionDetails, taskLoc, sessionId, initiator, cp, task);
    }

    /**
     * Summarizes a clustering training session and publishes the resulting events.
     *
     * @param taskLoc   location of the ML task whose session models are listed
     * @param sessionId identifier of the training session to summarize
     * @param initiator value forwarded as-is to the state-changed event
     * @param cp        core params providing project key, id and input dataset name for the event
     * @param task      ML task forwarded as-is to the state-changed event
     */
    public void publishClusteringSession(MLTaskLoc taskLoc, String sessionId, String initiator, AnalysisCoreParams cp, MLTask task) {
        TrainingSessionDetails clusteringDetails = new ClusteringTrainingSessionDetails();
        publishEvent(clusteringDetails, taskLoc, sessionId, initiator, cp, task);
    }

    /**
     * Fills {@code sessionDetails} from the models of the session, then publishes
     * two events in order: the detailed "mltask-train-done" reflected event and
     * the ML task state-changed event carrying the per-state counters.
     */
    private void publishEvent(TrainingSessionDetails sessionDetails, MLTaskLoc taskLoc, String sessionId, String initiator, AnalysisCoreParams cp, MLTask task) {
        // The details must be populated before either event is built from them.
        summarizeDetails(sessionDetails, taskLoc, sessionId);

        publishTrainingDetails(sessionDetails);
        publishSessionSummary(sessionDetails, sessionId, initiator, cp, task);
    }

    /**
     * Serializes the session details to JSON and publishes them wrapped in a
     * "mltask-train-done" reflected event.
     */
    private void publishTrainingDetails(TrainingSessionDetails sessionDetails) {
        // Typed as Object so the same JSON.toJsonObject overload is selected as before.
        Object payload = sessionDetails;
        IReflectedEventsService.ReflectedEvent reflected =
                new IReflectedEventsService.ReflectedEvent("mltask-train-done", JSON.toJsonObject(payload));
        logger.info((Object)"Publishing mltask-train-done reflected event");
        DSSEvent wrapped = new ReflectedEventEvent(reflected);
        pubSub.publish(wrapped);
    }

    /**
     * Publishes an {@link MLTaskStateChangedEvent} carrying the done / failed /
     * aborted / running counters of the session.
     */
    private void publishSessionSummary(TrainingSessionDetails sessionDetails, String sessionId, String initiator, AnalysisCoreParams cp, MLTask task) {
        TrainingSessionResultSummary counters = sessionDetails.summarizeSessionResults();
        MLTaskStateChangedEvent stateChanged = new MLTaskStateChangedEvent(
                task,
                cp.projectKey,
                cp.id,
                sessionId,
                cp.inputDatasetSmartName,
                false,
                initiator,
                counters);
        pubSub.publish(stateChanged);
    }

    /**
     * Populates {@code details} from every non-partition model of the session.
     *
     * <p>The project key is stored as its MD5 hex digest rather than in clear.
     * For each model, session-level settings are loaded once (retried on the
     * next model while {@code details.metric} is still null), a per-model info
     * object is built, the matching state counter is incremented and the model
     * info is appended to {@code details.models}.
     *
     * <p>Processing is best-effort: a model whose info cannot be built is
     * logged and skipped, without affecting the other models.
     *
     * @param details   accumulator to fill
     * @param taskLoc   location of the ML task, used to list the session's models
     * @param sessionId identifier of the training session
     */
    private void summarizeDetails(TrainingSessionDetails details, MLTaskLoc taskLoc, String sessionId) {
        details.projectKey = DigestUtils.md5Hex((String)taskLoc.analysisProjectKey);
        details.analysisId = taskLoc.analysisId;
        details.mlTaskId = taskLoc.mlTaskId;
        for (FullModelId fmi : taskLoc.listModelIds(sessionId)) {
            if (fmi.isModelPartition()) {
                continue;
            }
            try {
                if (details.metric == null) {
                    details.update(fmi);
                }
                TrainingSessionModelInfo tsmi = details.newTrainingSessionModelInfo(fmi);
                switch (tsmi.state) {
                    case ABORTED: {
                        ++details.aborted;
                        break;
                    }
                    case DONE: {
                        ++details.done;
                        // Isolated so a feature-handling failure cannot drop a model that
                        // has already been counted as done from details.models.
                        this.collectFeatureHandlings(details, fmi);
                        break;
                    }
                    case FAILED: {
                        ++details.failed;
                        break;
                    }
                    case PENDING:
                    case RUNNING: {
                        ++details.running;
                    }
                }
                details.models.add(tsmi);
            }
            catch (Exception e) {
                logger.warn((Object)("Failed to get status for model " + String.valueOf(fmi)), (Throwable)e);
            }
        }
    }

    /**
     * Best-effort collection of the feature-handling report of a trained model:
     * any failure is logged and otherwise ignored.
     */
    private void collectFeatureHandlings(TrainingSessionDetails details, FullModelId fmi) {
        try {
            details.updateFeatureHandlings(fmi);
        }
        catch (Exception e) {
            logger.warn((Object)("Failed to collect feature handlings for model " + String.valueOf(fmi)), (Throwable)e);
        }
    }

    public static class PredictionTrainingSessionDetails
    extends TrainingSessionDetails {
        public PredictionMLTask.PredictionType predictionType;

        public PredictionTrainingSessionDetails() {
            this.mlTaskType = "PREDICTION";
        }

        @Override
        protected TrainingSessionModelInfo newTrainingSessionModelInfo(FullModelId fmi) throws IOException {
            PreTrainPredictionModelingParams modeling = fmi.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
            ModelTrainInfo trainInfo = fmi.parseModelFile("train_info.json", ModelTrainInfo.class);
            PredictionTrainingSessionModelInfo tsmi = new PredictionTrainingSessionModelInfo(fmi, this.predictionType, trainInfo, modeling.algorithm);
            if (fmi.isPartitionedBaseModel() && fmi.getPartFile().exists()) {
                tsmi.partitionSummaries = this.addPartition(fmi, this.predictionType, modeling.algorithm);
            }
            return tsmi;
        }

        @Override
        protected void update(FullModelId fmi) throws IOException {
            PreTrainModelingParams modeling = fmi.parseModelFile("rmodeling_params.json", PreTrainModelingParams.class);
            this.metric = modeling.getEvaluationMetricName();
            ResolvedPredictionCoreParams pcp = (ResolvedPredictionCoreParams)fmi.getResolvedCoreParams();
            this.predictionType = pcp.prediction_type;
            this.diagnosticsSettings = pcp.diagnosticsSettings;
        }

        @Override
        protected void updateFeatureHandlings(FullModelId fmi) throws IOException {
            PredictionModelIntrinsicPerf iperf = FullModelId.getPredictionIntrinsicPerf(fmi.getModelFolder(), this.predictionType);
            if (iperf != null) {
                FeatureHandlingReport featureHandlingReport = new FeatureHandlingReport();
                if (iperf.rawImportance != null) {
                    featureHandlingReport.handlings = FeatureHandlingReport.parseFeatureHandlings(iperf.rawImportance.variables);
                    featureHandlingReport.scores = iperf.rawImportance.importances;
                } else if (iperf.lmCoefficients != null && iperf.lmCoefficients.rescaledCoefs != null) {
                    featureHandlingReport.handlings = FeatureHandlingReport.parseFeatureHandlings(iperf.lmCoefficients.variables);
                    double sumAbs = Arrays.stream(iperf.lmCoefficients.rescaledCoefs).map(Math::abs).sum();
                    featureHandlingReport.scores = sumAbs == 0.0 ? iperf.lmCoefficients.rescaledCoefs : Arrays.stream(iperf.lmCoefficients.rescaledCoefs).map(x -> Math.abs(x) / sumAbs).toArray();
                }
                this.featureHandlings.add(featureHandlingReport);
            }
        }

        private List<PredictionTrainingSessionModelInfo> addPartition(FullModelId fmi, PredictionMLTask.PredictionType predictionType, PreTrainPredictionModelingParams.Algorithm algorithm) {
            ArrayList<PredictionTrainingSessionModelInfo> partitionSummaries = new ArrayList<PredictionTrainingSessionModelInfo>();
            try {
                PartitionedModelExtract extract = fmi.getPartitionedModelExtract();
                for (PartitionedModelExtract.PartitionedModelSummary summary : extract.summaries.values()) {
                    FullModelId partitionFMI = FullModelId.parse(summary.snippet.fullModelId);
                    PredictionTrainingSessionModelInfo tsmi = new PredictionTrainingSessionModelInfo(partitionFMI, predictionType, summary.snippet.trainInfo, algorithm);
                    partitionSummaries.add(tsmi);
                }
            }
            catch (IOException e) {
                logger.warn((Object)"Failed to retrieve partition extract: ", (Throwable)e);
            }
            return partitionSummaries;
        }
    }

    /**
     * Accumulator describing one training session: per-state model counters,
     * identification of the ML task, the evaluation metric, per-model infos and
     * feature-handling reports. Subclasses supply the task-type specific loading.
     */
    public static abstract class TrainingSessionDetails {
        // Per-state model counters.
        public int done;
        public int failed;
        public int aborted;
        public int running;
        // Identification of the ML task the session belongs to.
        public String projectKey;
        public String analysisId;
        public String mlTaskId;
        public String mlTaskType;
        // Session-level settings, filled by update(FullModelId).
        public String metric;
        public List<TrainingSessionModelInfo> models = new ArrayList<>();
        public MLTask.DiagnosticsSettings diagnosticsSettings;
        public List<FeatureHandlingReport> featureHandlings = new ArrayList<>();

        /**
         * Returns a fresh summary holding a copy of the four per-state counters.
         */
        public TrainingSessionResultSummary summarizeSessionResults() {
            TrainingSessionResultSummary summary = new TrainingSessionResultSummary();
            summary.running = this.running;
            summary.aborted = this.aborted;
            summary.failed = this.failed;
            summary.done = this.done;
            return summary;
        }

        /**
         * Builds the per-model info for the given model.
         *
         * @throws IOException if model files cannot be read
         */
        protected abstract TrainingSessionModelInfo newTrainingSessionModelInfo(FullModelId fmi) throws IOException;

        /**
         * Loads the session-level settings (at least {@code metric}) from the given model.
         *
         * @throws IOException if model files cannot be read
         */
        protected abstract void update(FullModelId fmi) throws IOException;

        /**
         * Hook for collecting feature handlings of a trained model; does nothing by default.
         *
         * @throws IOException declared for overriding implementations
         */
        protected void updateFeatureHandlings(FullModelId fmi) throws IOException {
            // Intentionally empty: only some task types report feature handlings.
        }
    }

    public static class ClusteringTrainingSessionDetails
    extends TrainingSessionDetails {
        public ClusteringTrainingSessionDetails() {
            this.mlTaskType = "CLUSTERING";
        }

        @Override
        protected TrainingSessionModelInfo newTrainingSessionModelInfo(FullModelId fmi) throws IOException {
            ClusteringTrainingSessionModelInfo tsmi = new ClusteringTrainingSessionModelInfo();
            PreTrainClusteringModelingParams modeling = fmi.parseModelFile("rmodeling_params.json", PreTrainClusteringModelingParams.class);
            tsmi.algorithm = modeling.algorithm;
            ModelTrainInfo trainInfo = fmi.parseModelFile("train_info.json", ModelTrainInfo.class);
            tsmi.totalTime = trainInfo.trainingTime + trainInfo.preprocessingTime;
            tsmi.diagnosticsCount = MLDiagnostics.countDiagnostics(fmi);
            tsmi.state = trainInfo.state;
            return tsmi;
        }

        @Override
        protected void update(FullModelId fmi) throws IOException {
            PreTrainClusteringModelingParams modeling = fmi.parseModelFile("rmodeling_params.json", PreTrainClusteringModelingParams.class);
            this.metric = modeling.metrics.evaluationMetric.toString();
            ResolvedClusteringCoreParams coreParams = (ResolvedClusteringCoreParams)fmi.getResolvedCoreParams();
            this.diagnosticsSettings = coreParams.diagnosticsSettings;
        }
    }

    /**
     * Per-state model counters of a training session, copied from
     * {@link TrainingSessionDetails} by {@code summarizeSessionResults()} and
     * passed to the ML task state-changed event.
     */
    public static class TrainingSessionResultSummary {
        // Number of models counted in each train state for the session.
        public int done;
        public int failed;
        public int aborted;
        // Counts models that are either pending or running.
        public int running;
    }

    /**
     * Fields common to the per-model info of every ML task type.
     */
    public static abstract class TrainingSessionModelInfo {
        // Sum of the train info's trainingTime and preprocessingTime (unit not visible here).
        public long totalTime;
        // Train state read from the model's train info; drives the session counters.
        public ModelTrainInfo.ModelTrainState state;
        // Result of MLDiagnostics.countDiagnostics for the model.
        public Map<WarningsContext.WarningType, Integer> diagnosticsCount;
    }

    /**
     * Per-model info for prediction tasks: algorithm, hyperparameter search
     * descriptions, row counts and a few headline performance metrics.
     */
    public static class PredictionTrainingSessionModelInfo
    extends TrainingSessionModelInfo {
        PreTrainPredictionModelingParams.Algorithm algorithm;
        public ModelTrainInfo.PreSearchDescription preSearch;
        public ModelTrainInfo.PostSearchDescription postSearch;
        public Long trainRecords;
        public Long testRecords;
        public Long fullRecords;
        // Metrics stay null when "perf.json" is absent or not applicable to the prediction type.
        public Double auc;
        public Double f1;
        public Double r2;
        public List<PredictionTrainingSessionModelInfo> partitionSummaries;

        /**
         * Builds the info from the model's train info and, when the model has a
         * "perf.json" file, from the performance data matching the prediction type.
         *
         * @throws IOException if a model file cannot be read or parsed
         */
        public PredictionTrainingSessionModelInfo(FullModelId fmi, PredictionMLTask.PredictionType predictionType, ModelTrainInfo trainInfo, PreTrainPredictionModelingParams.Algorithm algorithm) throws IOException {
            this.algorithm = algorithm;

            boolean hasPerf = fmi.getModelFile("perf.json").exists();
            if (hasPerf) {
                this.readPerfMetrics(fmi, predictionType);
            }

            this.preSearch = trainInfo.preSearchDescription;
            this.postSearch = trainInfo.postSearchDescription;
            this.fullRecords = trainInfo.fullRows;
            this.testRecords = trainInfo.testRows;
            this.trainRecords = trainInfo.trainRows;
            this.totalTime = trainInfo.trainingTime + trainInfo.preprocessingTime;
            this.diagnosticsCount = MLDiagnostics.countDiagnostics(fmi);
            this.state = trainInfo.state;
        }

        /**
         * Reads "perf.json" with the class matching the prediction type and
         * copies the headline metrics; other prediction types are ignored.
         */
        private void readPerfMetrics(FullModelId fmi, PredictionMLTask.PredictionType predictionType) throws IOException {
            switch (predictionType) {
                case BINARY_CLASSIFICATION: {
                    BinaryClassificationModelPerf binaryPerf = fmi.parseModelFile("perf.json", BinaryClassificationModelPerf.class);
                    // F1 is taken at the cut corresponding to the optimal threshold.
                    int cutIndex = binaryPerf.thresholdIndex(binaryPerf.optimalThreshold);
                    this.f1 = binaryPerf.perCutData.f1[cutIndex];
                    this.auc = binaryPerf.tiMetrics.auc;
                    return;
                }
                case MULTICLASS: {
                    MulticlassModelPerf multiclassPerf = fmi.parseModelFile("perf.json", MulticlassModelPerf.class);
                    this.f1 = multiclassPerf.metrics.f1;
                    this.auc = multiclassPerf.metrics.mrocAUC;
                    return;
                }
                case REGRESSION: {
                    RegressionModelPerf regressionPerf = fmi.parseModelFile("perf.json", RegressionModelPerf.class);
                    this.r2 = regressionPerf.metrics.r2;
                    return;
                }
                default: {
                    return;
                }
            }
        }
    }

    /**
     * Per-model info for clustering tasks; adds the algorithm to the common fields.
     */
    public static class ClusteringTrainingSessionModelInfo
    extends TrainingSessionModelInfo {
        // Algorithm read from the model's "rmodeling_params.json".
        PreTrainClusteringModelingParams.Algorithm algorithm;
    }

    /**
     * Report pairing, for one model, the handling applied to each generated
     * feature (derived from the feature name's prefix) with a score.
     */
    public static class FeatureHandlingReport {
        // NOTE(review): callers in this file assign the full score array here while
        // handlings is capped at MAX_NB_FEATURES, so the two can differ in length - confirm intended.
        double[] scores;
        String[] handlings;
        /** Maximum number of handlings reported by {@link #parseFeatureHandlings(String[])}. */
        public static final int MAX_NB_FEATURES = 100;
        /**
         * Matches either "&lt;known handling prefix&gt;:&lt;anything&gt;" (prefix captured
         * in group 1) or the single-character name "_". DOTALL lets the part after
         * the colon contain line terminators, so such names are still recognised
         * by their prefix.
         */
        public static final Pattern pattern = Pattern.compile("(?:(num_flagonly|num_binarized|num_quantized|datetime_cyclical|NUM_DERIVATIVE|dummy|hashing|cat_flagpresence|ordinal|frequency|glmm|impact|hashvect|countvec|tfidfvec|sentence_vec|interaction):(?:.*))|(?:_$)", Pattern.DOTALL);

        /**
         * Maps a generated feature name to its handling: the known prefix when
         * the name is "prefix:...", "custom" when the name is exactly "_", and
         * "num_regular" otherwise.
         */
        private static String parseFeatureName(String featureName) {
            Matcher matcher = pattern.matcher(featureName);
            if (!matcher.matches()) {
                return "num_regular";
            }
            return "_".equals(matcher.group(0)) ? "custom" : matcher.group(1);
        }

        /**
         * Maps feature names to handlings, position by position.
         *
         * <p>When more than {@link #MAX_NB_FEATURES} names are given, only the
         * first and last {@code MAX_NB_FEATURES / 2} names are kept, in their
         * original order.
         *
         * @param featureNames generated feature names; must not be null
         * @return handlings, of length {@code min(featureNames.length, MAX_NB_FEATURES)}
         */
        public static String[] parseFeatureHandlings(String[] featureNames) {
            int total = featureNames.length;
            if (total <= MAX_NB_FEATURES) {
                String[] res = new String[total];
                for (int i = 0; i < total; ++i) {
                    res[i] = FeatureHandlingReport.parseFeatureName(featureNames[i]);
                }
                return res;
            }
            int half = MAX_NB_FEATURES / 2;
            String[] res = new String[MAX_NB_FEATURES];
            for (int i = 0; i < half; ++i) {
                // Fill the head from the front and the tail from the back.
                res[i] = FeatureHandlingReport.parseFeatureName(featureNames[i]);
                res[MAX_NB_FEATURES - 1 - i] = FeatureHandlingReport.parseFeatureName(featureNames[total - 1 - i]);
            }
            return res;
        }
    }
}

