/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.analysis.coreservices.TrainingSessionDetailsService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.WorkloadType;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.jobrunner.status.SerializedJobActivityStatus;
import com.dataiku.dip.reports.IReflectedEventsService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.ReflectedEventEvent;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.google.gson.Gson;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonSerializationContext;
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Reports finished ML trainings as "ml-training-finished" reflected events.
 *
 * <p>Two pub/sub topics are followed:
 * <ul>
 *   <li>"job-done": prediction/clustering training recipe activities of a job are grouped by the
 *       saved model they target, and one event is published per saved model.</li>
 *   <li>"mltask-state-change": once an ML task is no longer running, one event is published for
 *       its analysis, covering the non-partition models of the session.</li>
 * </ul>
 * Projects and objects appear in the payload only as MD5 hashes of their identifiers.
 */
@Service
public class MLTrainingTrackingService {
    private static final DKULogger logger = DKULogger.getLogger("dku.ml.tracking");

    /** Python packages whose versions are copied from the model train info into the payload. */
    private static final List<String> LIBRARIES_TO_REPORT = List.of(
            "gluonts", "lightgbm", "pandas", "scikit-learn", "scipy", "statsmodels", "torch", "tensorflow", "xgboost");

    static {
        // Register the flattening serializer for TrainingDetails with the JSON helper
        JSON.registerAdapter(TrainingDetails.class, new TrainingDetailsSerializer());
    }

    @Autowired
    private PubSubService pubSubService;

    /**
     * Subscribes the training trackers to the "job-done" and "mltask-state-change" topics.
     */
    @PostConstruct
    public void init() {
        this.pubSubService.subscribe("job-done", evt -> {
            // Group the job's training activities by the id of the saved model they target
            Map<String, List<SerializedJobActivityStatus>> activitiesPerModel = new HashMap<>();
            for (SerializedJobActivityStatus activity : evt.status.activities.values()) {
                if (isModelTrainingActivity(activity)) {
                    activitiesPerModel.computeIfAbsent(activity.targets.get(0).id, id -> new ArrayList<>()).add(activity);
                }
            }

            for (List<SerializedJobActivityStatus> activities : activitiesPerModel.values()) {
                TrainingSessionDetailsService.TrainingSessionResultSummary results = new TrainingSessionDetailsService.TrainingSessionResultSummary();
                FullModelId fmi = null;
                // Set once at least one activity produced a model version
                boolean includeVersions = false;

                for (SerializedJobActivityStatus activity : activities) {
                    SerializedJobActivityStatus.SerializedTargetDS target = activity.targets.get(0);
                    String[] idParts = target.id.split("\\.");
                    if (target.modelVersionId != null) {
                        // Targets carrying a version id count as done; the last such version is reported
                        results.done++;
                        includeVersions = true;
                        fmi = new FullModelId(idParts[0], idParts[1], target.modelVersionId);
                        continue;
                    }
                    results.failed++;
                    if (fmi != null) {
                        // A model to report on is already known: no need to read the saved model again
                        continue;
                    }
                    // No version was produced: fall back to the saved model's active version, if any
                    try (Transaction t = ((TransactionService) SpringUtils.getBean(TransactionService.class)).beginRead()) {
                        SavedModel sm = (SavedModel) ((SavedModelsDAO) SpringUtils.getBean(SavedModelsDAO.class)).getOrNullUnsafe(idParts[0], idParts[1]);
                        if (sm != null && sm.activeVersion != null) {
                            fmi = new FullModelId(idParts[0], idParts[1], sm.activeVersion);
                        }
                    }
                }

                if (fmi == null) {
                    continue;
                }
                TrainingDetails payload = TrainingDetails.forObject(fmi.projectKey, fmi.smId, ObjectType.SAVED_MODEL);
                payload.results = results;
                this.enrichPayloadFromCoreParams(payload, fmi);
                if (payload.taskType == null) {
                    // taskType stays null e.g. when the core params could not be read: skip reporting
                    continue;
                }
                if (includeVersions) {
                    this.enrichPayloadFromModelTrainInfo(payload, fmi);
                }
                String algorithm = fmi.getAlgorithmSlug();
                if (algorithm != null) {
                    payload.algorithms.merge(algorithm, 1, Integer::sum);
                }
                this.publishTrainingTrackingEvent(payload);
            }
        });

        this.pubSubService.subscribe("mltask-state-change", evt -> {
            // Only report once the ML task has stopped running
            if (evt.isRunning) {
                return;
            }
            TrainingDetails payload = TrainingDetails.forObject(evt.projectKey, evt.analysisId, ObjectType.ANALYSIS);
            payload.results = evt.results;
            MLTaskLoc loc = new MLTaskLoc(evt.projectKey, evt.analysisId, evt.taskId);
            // Model partitions are excluded from the report
            List<FullModelId> modelIds = loc.listModelIds(evt.sessionId).stream().filter(modelId -> !modelId.isModelPartition()).toList();
            if (!modelIds.isEmpty()) {
                // Core params and train info are taken from the first model of the session only
                FullModelId first = modelIds.get(0);
                this.enrichPayloadFromCoreParams(payload, first);
                this.enrichPayloadFromModelTrainInfo(payload, first);
            }
            for (FullModelId modelId : modelIds) {
                String algorithm = modelId.getAlgorithmSlug();
                if (algorithm != null) {
                    payload.algorithms.merge(algorithm, 1, Integer::sum);
                }
            }
            this.publishTrainingTrackingEvent(payload);
        });
        logger.debug("Done initializing ML usage tracking service");
    }

    /**
     * Returns whether a job activity is a prediction or clustering training recipe with at least one target.
     * The checks are evaluated in this order so that {@code targets} is only touched for training recipes.
     */
    private static boolean isModelTrainingActivity(SerializedJobActivityStatus activity) {
        if (activity.activityType != SerializedJobActivityStatus.ActivityType.RECIPE) {
            return false;
        }
        if (!Objects.equals(activity.recipeType, "prediction_training") && !Objects.equals(activity.recipeType, "clustering_training")) {
            return false;
        }
        return !activity.targets.isEmpty();
    }

    /**
     * Fills backend, task type, prediction type, code env kind and execution location from the
     * model's resolved core params. An IOException is logged and leaves the remaining fields unset.
     */
    private void enrichPayloadFromCoreParams(TrainingDetails payload, FullModelId fmi) {
        try {
            ResolvedCoreParams coreParams = fmi.getResolvedCoreParams();
            payload.backendType = coreParams.backendType;
            payload.taskType = coreParams.taskType;
            if (coreParams.taskType == MLTask.MLTaskType.PREDICTION) {
                payload.predictionType = ((ResolvedPredictionCoreParams) coreParams).prediction_type;
            }
            payload.codeEnvKind = codeEnvKindOf(coreParams.executionParams.envName);
            payload.executionLocation = switch (coreParams.executionParams.containerSelection.containerMode) {
                case EXPLICIT_CONTAINER -> ExecutionLocation.CONTAINER;
                case NONE -> ExecutionLocation.BACKEND;
                // Any other mode: ask the project's container selection whether user code runs containerized
                default -> new ContainerExecConfigSelector().workloadRunsContainerized_autoTXN(fmi.getProjectKey(), WorkloadType.USER_CODE)
                        ? ExecutionLocation.CONTAINER
                        : ExecutionLocation.BACKEND;
            };
        } catch (IOException e) {
            logger.warn("Could not read resolved core params", e);
        }
    }

    /** Classifies a code env name: none means builtin, an "INTERNAL_" prefix internal, anything else custom. */
    private static CodeEnvironmentKind codeEnvKindOf(String envName) {
        if (envName == null) {
            return CodeEnvironmentKind.BUILTIN;
        }
        return envName.startsWith("INTERNAL_") ? CodeEnvironmentKind.INTERNAL : CodeEnvironmentKind.CUSTOM;
    }

    /** Copies the python version and tracked library versions from the model train info, when available. */
    private void enrichPayloadFromModelTrainInfo(TrainingDetails payload, FullModelId fmi) {
        fmi.getTrainModelInfo().ifPresent(mti -> {
            payload.pythonVersion = mti.pythonVersion;
            payload.libraries = this.getLibrariesVersions(mti);
        });
    }

    /** Returns the versions of the {@link #LIBRARIES_TO_REPORT} packages present in the train info. */
    private Map<String, String> getLibrariesVersions(ModelTrainInfo modelTrainInfo) {
        Map<String, String> versions = new HashMap<>();
        if (modelTrainInfo.packagesVersion == null) {
            return versions;
        }
        for (String library : LIBRARIES_TO_REPORT) {
            if (modelTrainInfo.packagesVersion.has(library)) {
                versions.put(library, modelTrainInfo.packagesVersion.get(library).getAsString());
            }
        }
        return versions;
    }

    /** Wraps the payload into an "ml-training-finished" reflected event and publishes it. */
    private void publishTrainingTrackingEvent(TrainingDetails payload) {
        IReflectedEventsService.ReflectedEvent evt = new IReflectedEventsService.ReflectedEvent("ml-training-finished", JSON.toJsonObject(payload));
        logger.info("Publishing ml-training-finished reflected event");
        DSSEvent event = new ReflectedEventEvent(evt);
        this.pubSubService.publish(event);
    }

    /** Payload of the "ml-training-finished" event. */
    public static class TrainingDetails {
        /** MD5 hex of the project key. */
        String projecth;
        /** MD5 hex of "projectKey.objectId". */
        String objecth;
        String objectType;
        MLTask.MLTaskType taskType;
        PredictionMLTask.PredictionType predictionType;
        MLTask.BackendType backendType;
        /** Algorithm slug to number of models using it. */
        Map<String, Integer> algorithms = new HashMap<>();
        ExecutionLocation executionLocation;
        CodeEnvironmentKind codeEnvKind;
        String pythonVersion;
        /** Library name to version string. */
        Map<String, String> libraries = new HashMap<>();
        TrainingSessionDetailsService.TrainingSessionResultSummary results;

        /** Creates a payload identifying the object only through MD5 hashes of its keys. */
        public static TrainingDetails forObject(String projectKey, String objectId, ObjectType objectType) {
            TrainingDetails details = new TrainingDetails();
            details.projecth = DigestUtils.md5Hex(projectKey);
            details.objecth = DigestUtils.md5Hex(projectKey + "." + objectId);
            details.objectType = String.valueOf(objectType);
            return details;
        }
    }

    private enum CodeEnvironmentKind {
        BUILTIN,
        INTERNAL,
        CUSTOM
    }

    private enum ExecutionLocation {
        BACKEND,
        CONTAINER
    }

    private enum ObjectType {
        ANALYSIS,
        SAVED_MODEL
    }

    /**
     * Serializes {@link TrainingDetails} as a flat JSON object: the algorithms and libraries maps
     * and the results summary are replaced by top-level "algorithms__x", "libraries__x" and
     * "results__x" properties. Deserialization is not supported.
     */
    public static class TrainingDetailsSerializer implements JSON.Adapter<TrainingDetails> {
        @Override
        public JsonElement serialize(TrainingDetails details, Type type, JsonSerializationContext context) {
            // A fresh Gson has no custom adapters, so this adapter is not re-entered
            JsonObject ret = (JsonObject) JSON.parse(new Gson().toJson(details), JsonObject.class);
            // Locale.ROOT keeps the generated keys stable regardless of the JVM default locale
            for (Map.Entry<String, Integer> entry : details.algorithms.entrySet()) {
                ret.addProperty("algorithms__" + entry.getKey().toLowerCase(Locale.ROOT), entry.getValue());
            }
            ret.remove("algorithms");
            for (Map.Entry<String, String> entry : details.libraries.entrySet()) {
                ret.addProperty("libraries__" + entry.getKey().toLowerCase(Locale.ROOT), entry.getValue());
            }
            ret.remove("libraries");
            // results is copied from the triggering event and may be absent
            if (details.results != null) {
                ret.addProperty("results__aborted", details.results.aborted);
                ret.addProperty("results__done", details.results.done);
                ret.addProperty("results__failed", details.results.failed);
            }
            ret.remove("results");
            return ret;
        }

        @Override
        public TrainingDetails deserialize(JsonElement jsonElement, Type type, JsonDeserializationContext jsonDeserializationContext) throws JsonParseException {
            throw new UnsupportedOperationException("TrainingDetails deserialization is not implemented");
        }
    }
}

