/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.shared;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.clustering.PreTrainClusteringModelingParams;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.mec.drift.DriftParams;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKUDateUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang.StringUtils;

/**
 * Builds the key/value labels attached to trained models and model evaluations.
 *
 * <p>Every label key has the form {@code namespace:name} (see {@link #SEPARATOR}). Namespaces
 * include {@link #MODEL_NAMESPACE}, {@link #EVALUATION_NAMESPACE} and one namespace per dataset
 * role (train / test / evaluation / reference). Labels come from three sources:
 * <ul>
 *   <li>dataset tags;</li>
 *   <li>user-supplied labels, with project variables expanded;</li>
 *   <li>computed metadata (dates, algorithm, model name, partitions, ...).</li>
 * </ul>
 */
public class EvaluationLabelsHelper {
    public static final String SEPARATOR = ":";
    public static final String TRAIN_DATASET_NAMESPACE = "trainDataset";
    public static final String TEST_DATASET_NAMESPACE = "testDataset";
    public static final String EVALUATION_DATASET_NAMESPACE = "evaluationDataset";
    public static final String REFERENCE_DATASET_NAMESPACE = "referenceDataset";
    public static final String MODEL_NAMESPACE = "model";
    public static final String DEFAULT_VALUE = "true";
    public static final String ALGORITHM = "algorithm";
    public static final String META_LEARNER = "meta-learner";
    public static final String CAUSAL_LEARNING_METHOD = "learning-method";
    public static final String MODEL_NAME = "model-name";
    public static final String SESSION_NAME = "session-name";
    public static final String DATASET_NAME = "dataset-name";
    public static final String EVALUATION_NAMESPACE = "evaluation";
    public static final String DATE = "date";
    public static final String NAME = "name";
    public static final String PARTITIONS = "partitions";
    public static final String PARTITION_COUNT = "partition-count";
    public static final String DATA_DRIFT_COLUMN_HANDLING = "data-drift-column-handling";
    public static final String EMBEDDING_LLM_AS_A_JUDGE = "embedding-llm-as-a-judge";
    public static final String COMPLETION_LLM_AS_A_JUDGE = "completion-llm-as-a-judge";
    public static final String TS_NB_FORECAST_TIMESTEPS = "nb-evaluation-timesteps";

    // Fully-qualified label keys, composed from the public namespace/name constants above so the
    // two can never drift apart. These are compile-time constants identical to the literals the
    // class historically used (e.g. "model:date", "evaluationDataset:partition-count").
    private static final String MODEL_DATE_KEY = MODEL_NAMESPACE + SEPARATOR + DATE;
    private static final String MODEL_ALGORITHM_KEY = MODEL_NAMESPACE + SEPARATOR + ALGORITHM;
    private static final String MODEL_NAME_KEY = MODEL_NAMESPACE + SEPARATOR + NAME;
    private static final String MODEL_SESSION_NAME_KEY = MODEL_NAMESPACE + SEPARATOR + SESSION_NAME;
    private static final String MODEL_LEARNING_METHOD_KEY = MODEL_NAMESPACE + SEPARATOR + CAUSAL_LEARNING_METHOD;
    private static final String MODEL_META_LEARNER_KEY = MODEL_NAMESPACE + SEPARATOR + META_LEARNER;
    private static final String EVALUATION_DATE_KEY = EVALUATION_NAMESPACE + SEPARATOR + DATE;
    private static final String EVALUATION_MODEL_NAME_KEY = EVALUATION_NAMESPACE + SEPARATOR + MODEL_NAME;
    private static final String EVALUATION_NB_FORECAST_TIMESTEPS_KEY = EVALUATION_NAMESPACE + SEPARATOR + TS_NB_FORECAST_TIMESTEPS;
    private static final String EVALUATION_EMBEDDING_LLM_KEY = EVALUATION_NAMESPACE + SEPARATOR + EMBEDDING_LLM_AS_A_JUDGE;
    private static final String EVALUATION_COMPLETION_LLM_KEY = EVALUATION_NAMESPACE + SEPARATOR + COMPLETION_LLM_AS_A_JUDGE;
    private static final String EVALUATION_DATA_DRIFT_KEY = EVALUATION_NAMESPACE + SEPARATOR + DATA_DRIFT_COLUMN_HANDLING;
    private static final String EVALUATION_DATASET_PARTITIONS_KEY = EVALUATION_DATASET_NAMESPACE + SEPARATOR + PARTITIONS;
    private static final String EVALUATION_DATASET_PARTITION_COUNT_KEY = EVALUATION_DATASET_NAMESPACE + SEPARATOR + PARTITION_COUNT;

    private static final DKULogger logger = DKULogger.getLogger("dku.analysis.prediction");

    /**
     * Adds a dataset's tags (as labels) and its display name under {@code prefix}.
     *
     * @param dataset          dataset to describe; may be {@code null} (e.g. the result of a
     *                         {@code getOrNull} lookup), in which case nothing is added
     * @param variablesContext context used to expand variables in tag values
     * @param prefix           namespace for the produced keys
     * @param toFill           map receiving the labels
     */
    private static void addDatasetTags(SerializedDataset dataset, VariablesContext variablesContext, String prefix, Map<String, String> toFill) {
        if (dataset == null) {
            // Previously this dereferenced null and failed with an NPE; a missing dataset should
            // not prevent the remaining labels from being computed.
            logger.warn("Dataset not found, no labels added for namespace " + prefix);
            return;
        }
        convertTagsToLabels(dataset.tags, variablesContext, prefix, toFill);
        toFill.put(prefix + SEPARATOR + DATASET_NAME, dataset.getDisplayName());
    }

    /**
     * Converts tags of the form {@code key} or {@code key:value} into {@code prefix:key} labels.
     *
     * <p>Only the first {@link #SEPARATOR} splits key from value, so {@code "a:b:c"} yields key
     * {@code a} and value {@code "b:c"}. The value is expanded with {@code variablesContext}. A tag
     * with no value ({@code "key"} or {@code "key:"}) gets {@link #DEFAULT_VALUE}. Null tags are
     * skipped.
     *
     * @param tags             tags to convert; {@code null} or empty is a no-op
     * @param variablesContext context used to expand variables in tag values
     * @param prefix           namespace for the produced keys
     * @param toFill           map receiving the labels
     */
    protected static void convertTagsToLabels(List<String> tags, VariablesContext variablesContext, String prefix, Map<String, String> toFill) {
        if (CollectionUtils.isEmpty(tags)) {
            return;
        }
        for (String currentTag : tags) {
            if (currentTag == null) {
                continue;
            }
            // indexOf instead of String.split: split() drops trailing empty strings, which made
            // ":" crash (empty array) and stripped trailing separators from values.
            int separatorIndex = currentTag.indexOf(SEPARATOR);
            String tagKey = separatorIndex < 0 ? currentTag : currentTag.substring(0, separatorIndex);
            String rawValue = separatorIndex < 0 ? "" : currentTag.substring(separatorIndex + SEPARATOR.length());
            String value = rawValue.isEmpty() ? DEFAULT_VALUE : variablesContext.expand(rawValue);
            toFill.put(prefix + SEPARATOR + tagKey, value);
        }
    }

    /**
     * Adds tag labels for the datasets implied by {@code splitParams}.
     *
     * <p>The test dataset is tagged under both {@link #TEST_DATASET_NAMESPACE} and
     * {@link #EVALUATION_DATASET_NAMESPACE}. The train dataset is tagged under
     * {@link #TRAIN_DATASET_NAMESPACE}; it is the same as the test dataset unless the policy
     * names two explicit datasets. When no test dataset can be derived from the split,
     * {@code defaultInputDataset} is used.
     *
     * @throws NotImplementedException if the split's train/test policy is not handled
     * @throws IOException             if a dataset lookup fails
     */
    private static void getLabelsFromDatasetsTags(String projectKey, SplitParams splitParams, String defaultInputDataset, DatasetsDAO datasetsDAO, VariablesContext variablesContext, Map<String, String> toFill) throws IOException {
        String trainDatasetName = null;
        String testDatasetName = null;
        if (splitParams != null) {
            switch (splitParams.ttPolicy) {
                case EXPLICIT_FILTERING_SINGLE_DATASET:
                    testDatasetName = splitParams.efsdDatasetSmartName;
                    break;
                case EXPLICIT_FILTERING_TWO_DATASETS:
                    trainDatasetName = splitParams.eftdTrain.datasetSmartName;
                    testDatasetName = splitParams.eftdTest.datasetSmartName;
                    break;
                case SPLIT_SINGLE_DATASET:
                    testDatasetName = splitParams.ssdDatasetSmartName;
                    break;
                default:
                    throw new NotImplementedException("Unknown Train TestPolicy: " + splitParams.ttPolicy);
            }
        }
        if (StringUtils.isEmpty(testDatasetName)) {
            testDatasetName = defaultInputDataset;
        }
        SerializedDataset testDataset = (SerializedDataset) datasetsDAO.getOrNull(AnyLoc.resolveSmart(projectKey, testDatasetName));
        addDatasetTags(testDataset, variablesContext, TEST_DATASET_NAMESPACE, toFill);
        addDatasetTags(testDataset, variablesContext, EVALUATION_DATASET_NAMESPACE, toFill);

        // Single-dataset policies train and test on the same dataset: reuse it instead of a
        // second lookup.
        SerializedDataset trainDataset;
        if (trainDatasetName == null || trainDatasetName.equals(testDatasetName)) {
            trainDataset = testDataset;
        } else {
            trainDataset = (SerializedDataset) datasetsDAO.getOrNull(AnyLoc.resolveSmart(projectKey, trainDatasetName));
        }
        addDatasetTags(trainDataset, variablesContext, TRAIN_DATASET_NAMESPACE, toFill);
    }

    /**
     * Expands user labels (variables resolved via {@code variablesContext}) into {@code toFill}.
     *
     * <p>A label whose key already contains {@link #SEPARATOR} is treated as fully qualified and
     * kept as-is. Other keys are prefixed with {@code namespace:} when {@code namespace} is
     * non-empty. A null key is treated as the empty string.
     *
     * @param namespace        namespace for unqualified keys; {@code null}/empty means no prefix
     * @param labels           labels to expand; {@code null} or empty is a no-op
     * @param variablesContext context used to expand variables in label values
     * @param toFill           map receiving the labels (later entries overwrite earlier ones)
     */
    protected static void expandLabels(String namespace, List<SimpleKeyValue> labels, VariablesContext variablesContext, Map<String, String> toFill) {
        if (CollectionUtils.isEmpty(labels)) {
            return;
        }
        String prefix = StringUtils.isNotEmpty(namespace) ? namespace + SEPARATOR : "";
        for (SimpleKeyValue label : labels) {
            String labelKey = Objects.toString(label.key, "");
            String fullKey = labelKey.contains(SEPARATOR) ? labelKey : prefix + labelKey;
            toFill.put(fullKey, variablesContext.expand(label.value));
        }
    }

    /**
     * Converts a label map into a list of key/value pairs.
     *
     * @return a new mutable list, in the iteration order of {@code labels}
     */
    protected static List<SimpleKeyValue> collectLabelsMap(Map<String, String> labels) {
        return labels.entrySet().stream()
                .map(entry -> new SimpleKeyValue(entry.getKey(), entry.getValue()))
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Computes train-time labels for a clustering task.
     *
     * <p>Clustering has no train/test split, so the sampled dataset (or, if the sampling does not
     * name one, the analysis input dataset) is tagged under the train, test and evaluation dataset
     * namespaces. Task labels are then expanded under {@link #MODEL_NAMESPACE}.
     *
     * @return the labels, or an empty list if an {@link IOException} occurs while reading
     */
    public static List<SimpleKeyValue> getTrainTimeLabels(MLTaskLoc loc, ClusteringMLTask task, AnalysisCoreParams acp) {
        TransactionService transactionService = (TransactionService) SpringUtils.getBean(TransactionService.class);
        try (Transaction t = transactionService.retrieveOrBeginRead()) {
            DatasetsDAO datasetsDAO = (DatasetsDAO) SpringUtils.getBean(DatasetsDAO.class);
            VariablesService variablesService = (VariablesService) SpringUtils.getBean(VariablesService.class);
            VariablesContext variablesContext = variablesService.getForProject(loc.analysisProjectKey);

            String datasetName = task.sampling.datasetSmartName;
            if (StringUtils.isEmpty(datasetName)) {
                datasetName = acp.inputDatasetSmartName;
            }
            SerializedDataset dataset = (SerializedDataset) datasetsDAO.getOrNull(AnyLoc.resolveSmart(loc.analysisProjectKey, datasetName));

            Map<String, String> labelsMap = new HashMap<>();
            addDatasetTags(dataset, variablesContext, TRAIN_DATASET_NAMESPACE, labelsMap);
            addDatasetTags(dataset, variablesContext, TEST_DATASET_NAMESPACE, labelsMap);
            addDatasetTags(dataset, variablesContext, EVALUATION_DATASET_NAMESPACE, labelsMap);
            expandLabels(MODEL_NAMESPACE, task.labels, variablesContext, labelsMap);
            return collectLabelsMap(labelsMap);
        } catch (IOException e) {
            logger.error("Error while extracting dataset labels", e);
            return new ArrayList<>();
        }
    }

    /**
     * Computes train-time labels for a prediction task.
     *
     * <p>Dataset tag labels are derived from the task's split parameters (falling back to the
     * analysis input dataset). Task labels are then expanded under {@link #MODEL_NAMESPACE}.
     *
     * @return the labels, or an empty list if an {@link IOException} occurs while reading
     */
    public static List<SimpleKeyValue> getTrainTimeLabels(MLTaskLoc loc, PredictionMLTask task, AnalysisCoreParams acp) {
        TransactionService transactionService = (TransactionService) SpringUtils.getBean(TransactionService.class);
        try (Transaction t = transactionService.retrieveOrBeginRead()) {
            DatasetsDAO datasetsDAO = (DatasetsDAO) SpringUtils.getBean(DatasetsDAO.class);
            VariablesService variablesService = (VariablesService) SpringUtils.getBean(VariablesService.class);
            VariablesContext variablesContext = variablesService.getForProject(loc.analysisProjectKey);

            Map<String, String> labelsMap = new HashMap<>();
            getLabelsFromDatasetsTags(loc.analysisProjectKey, task.splitParams, acp.inputDatasetSmartName, datasetsDAO, variablesContext, labelsMap);
            expandLabels(MODEL_NAMESPACE, task.labels, variablesContext, labelsMap);
            return collectLabelsMap(labelsMap);
        } catch (IOException e) {
            logger.error("Error while extracting dataset labels", e);
            return new ArrayList<>();
        }
    }

    /**
     * Computes train-time labels for a model trained by a recipe (saved model).
     *
     * <p>Labels are applied in this order, later ones overwriting earlier ones on key collision:
     * <ol>
     *   <li>dataset tags from the split;</li>
     *   <li>recipe labels;</li>
     *   <li>saved-model tags;</li>
     *   <li>train end date (as both model and evaluation date);</li>
     *   <li>algorithm;</li>
     *   <li>causal labels, for causal prediction tasks;</li>
     *   <li>model name and session name.</li>
     * </ol>
     *
     * @throws IOException if a dataset lookup or reading the model's modeling params fails
     */
    public static List<SimpleKeyValue> getTrainTimeLabels_T(SerializedRecipe recipe, SplitParams splitParams, SavedModel sm, String defaultInputDataset, String algorithmName, ModelTrainInfo mti, ModelUserMeta mum) throws IOException {
        DatasetsDAO datasetsDAO = (DatasetsDAO) SpringUtils.getBean(DatasetsDAO.class);
        VariablesService variablesService = (VariablesService) SpringUtils.getBean(VariablesService.class);
        VariablesContext variablesContext = variablesService.getForProject(recipe.projectKey);

        Map<String, String> labelsMap = new HashMap<>();
        getLabelsFromDatasetsTags(recipe.projectKey, splitParams, defaultInputDataset, datasetsDAO, variablesContext, labelsMap);
        expandLabels(MODEL_NAMESPACE, recipe.labels, variablesContext, labelsMap);
        convertTagsToLabels(sm.tags, variablesContext, MODEL_NAMESPACE, labelsMap);

        String trainEndDateTime = DKUDateUtils.isoFormatLocal((long) mti.endTime);
        labelsMap.put(MODEL_DATE_KEY, trainEndDateTime);
        labelsMap.put(EVALUATION_DATE_KEY, trainEndDateTime);
        labelsMap.put(MODEL_ALGORITHM_KEY, algorithmName);

        if (sm.miniTask instanceof PredictionMLTask.CausalPredictionMLTask) {
            FullModelId fmi = new FullModelId(sm.projectKey, sm.id, sm.activeVersion);
            PreTrainPredictionModelingParams modelingParams = fmi.parseModelFile("rmodeling_params.json", PreTrainPredictionModelingParams.class);
            for (SimpleKeyValue causalLabel : causalLabels(modelingParams)) {
                labelsMap.put(causalLabel.key, causalLabel.value);
            }
        }

        labelsMap.put(MODEL_NAME_KEY, mum.name);
        if (mum.sessionName != null) {
            labelsMap.put(MODEL_SESSION_NAME_KEY, mum.sessionName);
        }
        return collectLabelsMap(labelsMap);
    }

    /**
     * Computes evaluation-time labels.
     *
     * <p>Labels are applied in this order, later ones overwriting earlier ones on key collision:
     * <ol>
     *   <li>model labels and evaluated model name (when {@code mum} is non-null);</li>
     *   <li>evaluation dataset tags;</li>
     *   <li>reference dataset tags (when non-null);</li>
     *   <li>the current time as evaluation date;</li>
     *   <li>recipe labels (no default namespace, so unqualified keys stay unprefixed);</li>
     *   <li>partitions;</li>
     *   <li>data-drift column handling.</li>
     * </ol>
     *
     * @param mum                     evaluated model metadata; may be {@code null}
     * @param referenceDataset        reference dataset; may be {@code null}
     * @param dataDriftColumnHandling per-column drift settings; may be {@code null}/empty
     */
    public static List<SimpleKeyValue> getEvaluationTimeLabels_T(String projectKey, ModelUserMeta mum, SerializedDataset evaluationDataset, SerializedDataset referenceDataset, List<SimpleKeyValue> recipeLabels, List<Partition> partitionsList, Map<String, DriftParams.PerColumnDriftParam> dataDriftColumnHandling) {
        VariablesService variablesService = (VariablesService) SpringUtils.getBean(VariablesService.class);
        VariablesContext variablesContext = variablesService.getForProject(projectKey);

        Map<String, String> labelsMap = new HashMap<>();
        if (mum != null) {
            expandLabels(MODEL_NAMESPACE, mum.labels, variablesContext, labelsMap);
            labelsMap.put(EVALUATION_MODEL_NAME_KEY, mum.name);
        }
        addDatasetTags(evaluationDataset, variablesContext, EVALUATION_DATASET_NAMESPACE, labelsMap);
        if (referenceDataset != null) {
            addDatasetTags(referenceDataset, variablesContext, REFERENCE_DATASET_NAMESPACE, labelsMap);
        }
        addEvaluationTime(labelsMap);
        expandLabels(null, recipeLabels, variablesContext, labelsMap);
        addPartitions(labelsMap, partitionsList);
        addDataDriftColumnHandling(labelsMap, dataDriftColumnHandling);
        return collectLabelsMap(labelsMap);
    }

    /**
     * Computes evaluation-time labels for a time-series evaluation.
     *
     * <p>Produces the same labels as
     * {@link #getEvaluationTimeLabels_T(String, ModelUserMeta, SerializedDataset, List, List, Map)},
     * plus the number of forecast timesteps. That label is omitted when
     * {@code maxNbForecastTimeSteps} is {@code null}.
     */
    public static List<SimpleKeyValue> getTSEvaluationTimeLabels_T(String projectKey, ModelUserMeta mum, SerializedDataset evaluationDataset, List<SimpleKeyValue> recipeLabels, List<Partition> partitionList, Map<String, DriftParams.PerColumnDriftParam> dataDriftColumnHandling, Integer maxNbForecastTimeSteps) {
        List<SimpleKeyValue> labels = getEvaluationTimeLabels_T(projectKey, mum, evaluationDataset, null, recipeLabels, partitionList, dataDriftColumnHandling);
        // Previously Integer.toString(int) auto-unboxed a null here; setUniqueKeyValue skips null.
        String timesteps = maxNbForecastTimeSteps == null ? null : maxNbForecastTimeSteps.toString();
        return setUniqueKeyValue(labels, EVALUATION_NB_FORECAST_TIMESTEPS_KEY, timesteps);
    }

    /**
     * Computes evaluation-time labels for an LLM evaluation.
     *
     * <p>There is no model metadata, reference dataset or drift handling. The friendly names of
     * the LLM-as-a-judge models are added only when non-blank.
     */
    public static List<SimpleKeyValue> getLLMEvaluationTimeLabels_T(String projectKey, SerializedDataset evaluationDataset, List<SimpleKeyValue> recipeLabels, List<Partition> partitionList, @Nullable String embeddingLLMFriendlyName, @Nullable String completionLLMFriendlyName) {
        List<SimpleKeyValue> labels = getEvaluationTimeLabels_T(projectKey, null, evaluationDataset, null, recipeLabels, partitionList, null);
        // setUniqueKeyValue already ignores blank values.
        labels = setUniqueKeyValue(labels, EVALUATION_EMBEDDING_LLM_KEY, embeddingLLMFriendlyName);
        labels = setUniqueKeyValue(labels, EVALUATION_COMPLETION_LLM_KEY, completionLLMFriendlyName);
        return labels;
    }

    /**
     * Returns a list where {@code key} appears exactly once, with {@code value}.
     *
     * <p>If {@code value} is blank (or {@code null}), {@code list} is returned unchanged.
     * Otherwise a new mutable list is returned: existing entries for {@code key} are removed and
     * the new pair is appended.
     */
    private static List<SimpleKeyValue> setUniqueKeyValue(List<SimpleKeyValue> list, String key, String value) {
        if (StringUtils.isBlank(value)) {
            return list;
        }
        // toCollection(ArrayList::new): Collectors.toList() does not guarantee a mutable result.
        List<SimpleKeyValue> result = list.stream()
                .filter(kv -> !StringUtils.equals(key, kv.key))
                .collect(Collectors.toCollection(ArrayList::new));
        result.add(new SimpleKeyValue(key, value));
        return result;
    }

    /**
     * Shorthand for the evaluation-label overload without a reference dataset.
     */
    public static List<SimpleKeyValue> getEvaluationTimeLabels_T(String projectKey, ModelUserMeta mum, SerializedDataset evaluationDataset, List<SimpleKeyValue> recipeLabels, List<Partition> partitionList, Map<String, DriftParams.PerColumnDriftParam> dataDriftColumnHandling) {
        return getEvaluationTimeLabels_T(projectKey, mum, evaluationDataset, null, recipeLabels, partitionList, dataDriftColumnHandling);
    }

    /**
     * Replaces the model identity labels of {@code existingLabels} with values from {@code ms}.
     *
     * <p>Any existing model algorithm, name, session-name, learning-method and meta-learner labels
     * are dropped. Then the following are appended:
     * <ul>
     *   <li>causal labels, for prediction params with a causal method;</li>
     *   <li>the algorithm;</li>
     *   <li>the model name;</li>
     *   <li>the session name, when set.</li>
     * </ul>
     *
     * @return a new mutable list
     * @throws IllegalArgumentException if the modeling params type is not supported
     */
    public static List<SimpleKeyValue> setAlgorithmModelNameAndSessionNameToLabels(List<SimpleKeyValue> existingLabels, WorkSet.ModelingSet ms) {
        // Causal keys are included so repeated calls cannot accumulate duplicate entries.
        List<String> replacedKeys = Arrays.asList(MODEL_ALGORITHM_KEY, MODEL_NAME_KEY, MODEL_SESSION_NAME_KEY, MODEL_LEARNING_METHOD_KEY, MODEL_META_LEARNER_KEY);
        List<SimpleKeyValue> labels = existingLabels.stream()
                .filter(kv -> !replacedKeys.contains(kv.key))
                .collect(Collectors.toCollection(ArrayList::new));

        if (ms.modelingParams instanceof PreTrainPredictionModelingParams) {
            PreTrainPredictionModelingParams predictionModelingParams = (PreTrainPredictionModelingParams) ms.modelingParams;
            labels.addAll(causalLabels(predictionModelingParams));
            labels.add(new SimpleKeyValue(MODEL_ALGORITHM_KEY, predictionModelingParams.algorithm.name()));
        } else if (ms.modelingParams instanceof DeepHubPreTrainModelingParams) {
            labels.add(new SimpleKeyValue(MODEL_ALGORITHM_KEY, ms.modelingParams.generateName()));
        } else if (ms.modelingParams instanceof PreTrainClusteringModelingParams) {
            labels.add(new SimpleKeyValue(MODEL_ALGORITHM_KEY, ((PreTrainClusteringModelingParams) ms.modelingParams).algorithm.name()));
        } else {
            String paramsType = ms.modelingParams == null ? "null" : ms.modelingParams.getClass().getSimpleName();
            throw new IllegalArgumentException("Unsupported modeling params: " + paramsType);
        }

        labels.add(new SimpleKeyValue(MODEL_NAME_KEY, ms.userMeta.name));
        if (ms.userMeta.sessionName != null) {
            labels.add(new SimpleKeyValue(MODEL_SESSION_NAME_KEY, ms.userMeta.sessionName));
        }
        return labels;
    }

    /**
     * Builds the causal-learning labels for prediction modeling params.
     *
     * <p>The learning method is always included when set. The meta-learner is included only when
     * the method is {@code META_LEARNER} and a meta-learner is set.
     *
     * @param params modeling params; may be {@code null}
     * @return a new list; empty when {@code params} or its causal method is {@code null}
     */
    private static List<SimpleKeyValue> causalLabels(PreTrainPredictionModelingParams params) {
        List<SimpleKeyValue> result = new ArrayList<>();
        if (params == null || params.causal_method == null) {
            return result;
        }
        result.add(new SimpleKeyValue(MODEL_LEARNING_METHOD_KEY, params.causal_method.name()));
        if (PredictionModelingParams.CausalLearningMethod.META_LEARNER.equals(params.causal_method) && params.meta_learner != null) {
            result.add(new SimpleKeyValue(MODEL_META_LEARNER_KEY, params.meta_learner.displayName));
        }
        return result;
    }

    /**
     * Sets the model name label.
     *
     * <p>The label is unchanged when {@code modelName} is blank.
     */
    public static List<SimpleKeyValue> setModelNameToLabels(List<SimpleKeyValue> existingLabels, String modelName) {
        return setUniqueKeyValue(existingLabels, MODEL_NAME_KEY, modelName);
    }

    /**
     * Sets both the model date and the evaluation date labels to {@code trainTime}.
     *
     * <p>The time is formatted with {@code DKUDateUtils.isoFormatLocal}.
     */
    public static List<SimpleKeyValue> setTrainTime(List<SimpleKeyValue> existingLabels, long trainTime) {
        String trainTimeStr = DKUDateUtils.isoFormatLocal(trainTime);
        List<SimpleKeyValue> labels = setUniqueKeyValue(existingLabels, MODEL_DATE_KEY, trainTimeStr);
        return setUniqueKeyValue(labels, EVALUATION_DATE_KEY, trainTimeStr);
    }

    /**
     * Sets the evaluation date label to the current time.
     */
    private static void addEvaluationTime(Map<String, String> labelsMap) {
        labelsMap.put(EVALUATION_DATE_KEY, DKUDateUtils.isoFormatLocalNow());
    }

    /**
     * Returns the labels of an MLflow saved model version.
     *
     * <p>The only label is the model name.
     *
     * @return a new mutable list
     */
    public static List<SimpleKeyValue> getMLflowSMVLabels(String modelName) {
        List<SimpleKeyValue> labels = new ArrayList<>();
        labels.add(new SimpleKeyValue(MODEL_NAME_KEY, modelName));
        return labels;
    }

    /**
     * Adds the evaluation dataset's partition ids (comma-separated) and partition count.
     *
     * <p>Nothing is added when the list is {@code null} or empty. Nothing is added either when
     * the list holds a single partition with id {@code "NP"}.
     * NOTE(review): "NP" presumably marks a non-partitioned dataset; confirm.
     */
    public static void addPartitions(Map<String, String> labelsMap, List<Partition> partitionList) {
        if (partitionList == null || partitionList.isEmpty()) {
            return;
        }
        if (partitionList.size() == 1 && "NP".equals(partitionList.get(0).id())) {
            return;
        }
        String partitions = partitionList.stream().map(Partition::id).collect(Collectors.joining(", "));
        labelsMap.put(EVALUATION_DATASET_PARTITIONS_KEY, partitions);
        labelsMap.put(EVALUATION_DATASET_PARTITION_COUNT_KEY, String.valueOf(partitionList.size()));
    }

    /**
     * Adds a summary of the per-column data-drift handling.
     *
     * <p>The summary is a sorted, comma-separated list of {@code column(HANDLING)} entries, where
     * a disabled column is shown as {@code IGNORED}. Nothing is added for a {@code null} or empty
     * map.
     */
    public static void addDataDriftColumnHandling(Map<String, String> labelsMap, Map<String, DriftParams.PerColumnDriftParam> dataDriftColumnHandling) {
        if (MapUtils.isEmpty(dataDriftColumnHandling)) {
            return;
        }
        String summary = dataDriftColumnHandling.entrySet().stream()
                .map(entry -> describeDriftHandling(entry.getKey(), entry.getValue()))
                .sorted()
                .collect(Collectors.joining(", "));
        labelsMap.put(EVALUATION_DATA_DRIFT_KEY, summary);
    }

    /**
     * Formats one column's drift handling as {@code column(HANDLING)}.
     *
     * <p>{@code HANDLING} is {@code IGNORED} when the column is disabled.
     */
    private static String describeDriftHandling(String column, DriftParams.PerColumnDriftParam param) {
        String handling = param.enabled ? param.handling.name() : "IGNORED";
        return column + "(" + handling + ")";
    }
}

