/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.DesignImagesDataService;
import com.dataiku.dip.analysis.ml.prediction.PredictedDataService;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubColumnFormat;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedDeepHubPredictionCoreParams;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.shaker.filter.FilterRequest;
import com.dataiku.dip.shaker.filter.FilteringExecutor;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.MemScriptRunner;
import com.dataiku.dip.shaker.server.SerializedMemTableV2;
import com.dataiku.dip.shaker.server.SerializedTableChunk;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import com.dataiku.j2ts.annotations.UIModel;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.directory.api.util.Strings;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class DeephubImagesDataService {
    public static final String DETECTION_COLUMN_NAME = "prediction";
    public static final String PAIRING_COLUMN_NAME = "pairing";
    public static final String ENRICHED_VALID_COLUMN_NAME = "enrichedValid";
    public static final String ENRICHED_FILTERED_COLUMN_NAME = "enrichedFiltered";
    @Autowired
    private PredictedDataService predictedDataService;
    @Autowired
    private DesignImagesDataService designImagesDataService;

    /**
     * Builds the executor restricting design images to the requested target categories.
     *
     * @param task ML task supplying the target column name and the prediction type
     * @param deepHubFilterRequest requested categories; may be {@code null}
     * @return {@code null} when the request is absent or lists no category, otherwise an executor
     *         wrapping a single active filter element on the task's target column
     * @throws IllegalArgumentException if the prediction type is neither object detection nor
     *         image classification
     */
    private static FilteringExecutor getDesignFilteringExecutor(PredictionMLTask task, DeepHubDesignFilterRequest deepHubFilterRequest) throws Exception {
        List<String> categories = deepHubFilterRequest == null ? null : deepHubFilterRequest.targetCategories;
        if (categories == null || categories.isEmpty()) {
            // Nothing to filter on: callers treat a null executor as "no filtering".
            return null;
        }

        FilterRequest request = new FilterRequest();
        FilterRequest.FilterElement categoryFilter = new FilterRequest.FilterElement();
        categoryFilter.active = true;
        categoryFilter.column = task.targetVariable;
        categoryFilter.selectedValues = categories.toArray(new String[0]);
        // The facet type depends on how the target column is encoded for each prediction type.
        categoryFilter.type = switch (task.predictionType) {
            case DEEP_HUB_IMAGE_OBJECT_DETECTION -> FilterRequest.FilterType.BOUNDING_BOX_FACET;
            case DEEP_HUB_IMAGE_CLASSIFICATION -> FilterRequest.FilterType.ALPHANUM_FACET;
            default -> throw new IllegalArgumentException("Unsupported prediction type for filtering: '" + task.predictionType + "'");
        };
        request.elements.add(categoryFilter);
        return new FilteringExecutor(request);
    }

    /**
     * Fetches the design images table, restricted by the optional category filter.
     *
     * @param filterRequest requested target categories; {@code null} or empty means no filtering
     * @return the table with its report, as produced by the design images data service
     */
    private MemScriptRunner.TableWithReport getDesignImagesDataTableWithReport(AuthCtx user, AnalysisCoreParams acp, PredictionMLTask task, DeepHubDesignFilterRequest filterRequest) throws Exception {
        return this.designImagesDataService.getDesignImagesDataTableWithReport(
                user, acp, DeephubImagesDataService.getDesignFilteringExecutor(task, filterRequest));
    }

    /**
     * Refreshes the design images sample, restricted by the optional category filter.
     *
     * @param nbRows row count forwarded to the design images data service
     * @param filterRequest requested target categories; {@code null} or empty means no filtering
     * @return the serialized table returned by the design images data service
     */
    public SerializedMemTableV2 refreshDesignImagesDataSample(AuthCtx user, AnalysisCoreParams acp, int nbRows, PredictionMLTask task, DeepHubDesignFilterRequest filterRequest) throws Exception {
        return this.designImagesDataService.refreshDesignImagesDataSample(
                user, acp, nbRows, DeephubImagesDataService.getDesignFilteringExecutor(task, filterRequest));
    }

    /**
     * Returns a chunk of the design images table, restricted by the optional category filter.
     *
     * @param firstRow first-row argument forwarded to the design images data service
     * @param nbRows row-count argument forwarded to the design images data service
     * @param filterRequest requested target categories; {@code null} or empty means no filtering
     * @return the chunk returned by the design images data service
     */
    public SerializedTableChunk getDesignImagesDataChunk(AuthCtx user, AnalysisCoreParams acp, PredictionMLTask task, int firstRow, int nbRows, DeepHubDesignFilterRequest filterRequest) throws Exception {
        return this.designImagesDataService.getDesignImagesDataChunk(
                user, acp, firstRow, nbRows, DeephubImagesDataService.getDesignFilteringExecutor(task, filterRequest));
    }

    /**
     * Picks image paths at random, with replacement, from the unfiltered design images table.
     *
     * @param task ML task; must be a {@link PredictionMLTask.DeepHubPredictionMLTask}, whose
     *        {@code pathColumn} names the column holding the image paths
     * @param numImagePaths number of paths to draw; a non-positive value yields an empty list
     * @return the drawn paths (the same path may appear several times), or an empty list when the
     *         table has no rows
     * @throws ClassCastException if {@code task} is not a DeepHub prediction task
     */
    public List<String> getRandomImagePaths(AuthCtx user, AnalysisCoreParams acp, PredictionMLTask task, int numImagePaths) throws Exception {
        MemScriptRunner.TableWithReport dataTable = this.getDesignImagesDataTableWithReport(user, acp, task, null);
        String pathColumnName = ((PredictionMLTask.DeepHubPredictionMLTask)task).pathColumn;
        MemColumn pathColumn = dataTable.table.column(pathColumnName);
        List<String> ret = new ArrayList<>();
        int rowCount = dataTable.table.nrows();
        if (rowCount <= 0) {
            // Random.nextInt(bound) rejects a non-positive bound, so an empty table has nothing to draw from.
            return ret;
        }
        Random r = new Random();
        for (int i = 0; i < numImagePaths; ++i) {
            int imageIndex = r.nextInt(rowCount);
            ret.add(dataTable.table.rows.get(imageIndex).get(pathColumn));
        }
        return ret;
    }

    /**
     * Creates the derived-columns computer matching the model's prediction type.
     *
     * @param deephubCoreParams resolved core params whose {@code prediction_type} selects the computer
     * @param filter filter cast to the type expected by the selected computer; may be {@code null}
     * @param fmi model whose head ML task is handed to the computer
     * @throws IllegalArgumentException if the prediction type is neither object detection nor
     *         image classification
     */
    private DeepHubPredictionReportDerivedColumnsComputer getDerivedColumnsComputer(ResolvedDeepHubPredictionCoreParams deephubCoreParams, ComputerVisionPredictedFilter filter, FullModelId fmi) throws IOException {
        switch (deephubCoreParams.prediction_type) {
            case PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION -> {
                return new ObjectDetectionDerivedColumnsComputer((ObjectDetectionPredictedFilter)filter, fmi.getHeadMLTask());
            }
            case PredictionMLTask.PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION -> {
                return new ImageClassificationDerivedColumnsComputer((ImageClassificationPredictedFilter)filter, fmi.getHeadMLTask());
            }
            default -> throw new IllegalArgumentException("Unsupported prediction type for filtering: '" + deephubCoreParams.prediction_type + "'");
        }
    }

    /**
     * Recomputes the predicted images table for a model and serializes a sample of it.
     *
     * @param fmi model whose session file {@code script.json} provides the shaker script
     * @param filter filtering/sorting criteria; may be {@code null}
     * @param nbRows row-count argument passed to {@code SerializedMemTableV2.fill}
     * @return the filled serialized table
     */
    public SerializedMemTableV2 refreshPredictedImagesDataSample(FullModelId fmi, ComputerVisionPredictedFilter filter, int nbRows) throws Exception {
        SerializedShakerScript script = fmi.parseSessionFile("script.json", SerializedShakerScript.class);
        MemScriptRunner.TableWithReport tableWithReport = this.getPredictedImagesDataTableWithReport(fmi, script, filter);

        SerializedMemTableV2 serialized = new SerializedMemTableV2();
        serialized.fill(tableWithReport, script, nbRows, Integer.MAX_VALUE);
        return serialized;
    }

    /**
     * Recomputes the predicted images table for a model and returns one chunk of it.
     *
     * @param fmi model whose session file {@code script.json} provides the shaker script
     * @param firstRow first-row argument of the chunk
     * @param nbRows row-count argument of the chunk
     * @param filter filtering/sorting criteria; may be {@code null}
     * @return the chunk, spanning every column of the computed table and filled using the
     *         script's coloring and column selection
     */
    public SerializedTableChunk getPredictedImagesDataChunk(FullModelId fmi, int firstRow, Integer nbRows, ComputerVisionPredictedFilter filter) throws Exception {
        SerializedShakerScript script = fmi.parseSessionFile("script.json", SerializedShakerScript.class);
        MemScriptRunner.TableWithReport tableWithReport = this.getPredictedImagesDataTableWithReport(fmi, script, filter);

        // The chunk covers all columns: from column 0 up to the table's column count.
        SerializedTableChunk chunk = new SerializedTableChunk(firstRow, nbRows, 0, tableWithReport.table.columnsList.size());
        chunk.fill(tableWithReport.table, tableWithReport.filters, script.coloring, script.columnsSelection);
        return chunk;
    }

    /**
     * Computes the predicted images table for a DeepHub model, with the derived columns, filter
     * and sorting that correspond to {@code filter}.
     *
     * @param fmi model; its resolved core params must be DeepHub prediction params
     * @param sss shaker script providing the exploration sampling
     * @param filter filtering/sorting criteria; may be {@code null}
     * @return the uncached, filtered table with its report
     */
    private MemScriptRunner.TableWithReport getPredictedImagesDataTableWithReport(FullModelId fmi, SerializedShakerScript sss, ComputerVisionPredictedFilter filter) throws Exception {
        ResolvedCoreParams resolved = fmi.getResolvedCoreParams();
        assert (resolved.backendType == MLTask.BackendType.DEEP_HUB);
        DeepHubPredictionReportDerivedColumnsComputer computer =
                this.getDerivedColumnsComputer((ResolvedDeepHubPredictionCoreParams)resolved, filter, fmi);
        return this.predictedDataService.getUncachedFiltered_NT(
                fmi,
                sss.explorationSampling,
                computer.getFilterRequest(),
                computer,
                computer.getSortingRequest());
    }

    /**
     * Filter applied to design images: the target categories to keep.
     */
    @UIModel
    public static class DeepHubDesignFilterRequest {
        // Values selected on the task's target column; a null or empty list disables filtering
        // (getDesignFilteringExecutor returns null in that case).
        List<String> targetCategories = new ArrayList<String>();
    }

    /**
     * Derived-columns computer for object-detection prediction reports.
     *
     * <p>For each row it pairs detections with ground truths, splits the resulting items into
     * those matching the filter and those rejected by it, and writes four columns: a numeric
     * filter flag, a numeric sort key, and the JSON of the matching and rejected items.
     */
    static class ObjectDetectionDerivedColumnsComputer
    extends PredictedDataService.BaseDerivedColumnsComputer
    implements DeepHubPredictionReportDerivedColumnsComputer {
        static final String SORT_COLUMN = "__dku__deephub_sorting_column";
        static final String FILTER_COLUMN = "__dku__deephub_filter_column";
        private final PredictionMLTask mlTask;
        private final ObjectDetectionPredictedFilter filter;
        private ColumnFactory columnFactory;
        private Column targetCol;
        private Column detectionCol;
        private Column pairingCol;
        private Column sortColumn;
        private Column filterColumn;
        private Column enrichedValidCol;
        private Column enrichedFilteredCol;

        /**
         * @param filter filtering/sorting criteria; {@code null} is replaced by a default filter
         *        (no category constraint, no sorting, default thresholds) so that row computation
         *        never dereferences a null filter
         * @param mlTask head ML task; must be a {@link PredictionMLTask}
         */
        ObjectDetectionDerivedColumnsComputer(ObjectDetectionPredictedFilter filter, MLTask mlTask) {
            this.filter = filter != null ? filter : new ObjectDetectionPredictedFilter();
            this.mlTask = (PredictionMLTask)mlTask;
        }

        /** Returns whether the filter constrains ground truths or detections (any type other than ANY). */
        private boolean actualFilter() {
            boolean filterOnGroundTruth = this.filter.groundTruth != null && this.filter.groundTruth.type != CategoryType.ANY;
            boolean filterOnDetections = this.filter.detection != null && this.filter.detection.type != CategoryType.ANY;
            return filterOnGroundTruth || filterOnDetections;
        }

        /** Returns whether a sorting was requested. */
        private boolean actualSorting() {
            return this.filter.sorting != null;
        }

        /**
         * @return {@code null} when no category constraint is set; otherwise a numerical facet on
         *         {@link #FILTER_COLUMN} with a minimum of 0, which keeps the rows flagged 1 by
         *         {@link #compute} and drops those flagged -1
         */
        @Override
        public FilterRequest getFilterRequest() {
            if (!this.actualFilter()) {
                return null;
            }
            FilterRequest filterRequest = new FilterRequest();
            FilterRequest.FilterElement filterElement = new FilterRequest.FilterElement();
            filterElement.active = true;
            filterElement.column = FILTER_COLUMN;
            filterElement.type = FilterRequest.FilterType.NUMERICAL_FACET;
            filterElement.minValue = 0.0;
            filterRequest.elements.add(filterElement);
            return filterRequest;
        }

        /**
         * @return {@code null} when no sorting was requested; otherwise a single ascending sort on
         *         {@link #SORT_COLUMN} (the requested direction is encoded in the sign of the key)
         */
        @Override
        public List<SerializedShakerScript.TableSorting> getSortingRequest() {
            if (!this.actualSorting()) {
                return null;
            }
            List<SerializedShakerScript.TableSorting> ret = new ArrayList<>();
            SerializedShakerScript.TableSorting tableSorting = new SerializedShakerScript.TableSorting();
            tableSorting.column = SORT_COLUMN;
            tableSorting.ascending = true;
            ret.add(tableSorting);
            return ret;
        }

        /**
         * Builds the enriched items of a row: the pairings above both thresholds, then every
         * ground truth left unpaired, then every unpaired detection above the confidence threshold.
         *
         * <p>A pairing is kept when its IoU is strictly greater than {@code minIOU} and the
         * confidence of its detection is strictly greater than {@code minConfidence}. Pairing ids
         * are used as indexes into the parsed detection and ground-truth lists.
         */
        private List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> buildEnriched(Row row, Column pairingCol, Column detectionCol, Column targetCol) {
            List<DeepHubColumnFormat.PairingItem> pairing = DeepHubColumnFormat.parseObjectDetectionPairing(row.get(pairingCol));
            List<DeepHubColumnFormat.ObjectDetectionPredictedItem> detections = DeepHubColumnFormat.parseObjectDetectionPrediction(row.get(detectionCol));
            List<DeepHubColumnFormat.ObjectDetectionTargetItem> groundTruths = DeepHubColumnFormat.parseObjectDetectionTarget(row.get(targetCol));

            List<DeepHubColumnFormat.PairingItem> filteredPairing = pairing.stream()
                    .filter(p -> p.iou > (double)this.filter.minIOU)
                    .filter(p -> detections.get((int)p.det_id).confidence > this.filter.minConfidence)
                    .collect(Collectors.toList());

            // Collected into an ArrayList explicitly because unpaired items are appended below.
            List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> res = filteredPairing.stream()
                    .map(p -> new DeepHubColumnFormat.EnrichedObjectDetectionPairedItem(p.iou, detections.get((int)p.det_id), groundTruths.get((int)p.gt_id)))
                    .collect(Collectors.toCollection(ArrayList::new));

            Set<Integer> pairedGroundTruths = filteredPairing.stream().map(p -> (int)p.gt_id).collect(Collectors.toSet());
            for (int i = 0; i < groundTruths.size(); ++i) {
                if (pairedGroundTruths.contains(i)) {
                    continue;
                }
                res.add(new DeepHubColumnFormat.EnrichedObjectDetectionPairedItem(groundTruths.get(i)));
            }

            Set<Integer> pairedDetections = filteredPairing.stream().map(p -> (int)p.det_id).collect(Collectors.toSet());
            for (int i = 0; i < detections.size(); ++i) {
                if (pairedDetections.contains(i)) {
                    continue;
                }
                DeepHubColumnFormat.ObjectDetectionPredictedItem detection = detections.get(i);
                if (!(detection.confidence > this.filter.minConfidence)) {
                    continue;
                }
                res.add(new DeepHubColumnFormat.EnrichedObjectDetectionPairedItem(detection));
            }
            return res;
        }

        /**
         * Returns whether an enriched item satisfies the category constraints.
         *
         * <p>For each of ground truth and detection: type NONE requires the item to have none,
         * type ONE requires it to have one whose category equals the criterion value, and any
         * other type (or a null criterion) accepts everything.
         */
        private boolean filterEnrichedItem(DeepHubColumnFormat.EnrichedObjectDetectionPairedItem item) {
            if (!this.actualFilter()) {
                return true;
            }
            Category groundTruthCriterion = this.filter.groundTruth;
            if (groundTruthCriterion != null) {
                if (groundTruthCriterion.type == CategoryType.NONE && item.groundTruth != null) {
                    return false;
                }
                if (groundTruthCriterion.type == CategoryType.ONE
                        && (item.groundTruth == null || !Strings.equals(item.groundTruth.category, groundTruthCriterion.value))) {
                    return false;
                }
            }
            Category detectionCriterion = this.filter.detection;
            if (detectionCriterion != null) {
                if (detectionCriterion.type == CategoryType.NONE && item.detection != null) {
                    return false;
                }
                if (detectionCriterion.type == CategoryType.ONE
                        && (item.detection == null || !Strings.equals(item.detection.category, detectionCriterion.value))) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Computes the row's sort key from its matching items. Only called when a sorting is set.
         *
         * <p>Values are negated for a descending sort, and the maximum signed value is returned
         * ({@code Double.MAX_VALUE} when no item contributes a value).
         * NOTE(review): combined with the ascending table sort, rows end up ordered by their
         * largest item value when ascending and by their smallest when descending; confirm this
         * is the intended choice of "best" value.
         *
         * @throws IllegalArgumentException if the sorting mode is not handled
         */
        private double getBestSortValue(List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> enrichedValid) {
            int sign = this.filter.sorting.ascending ? 1 : -1;
            return switch (this.filter.sorting.sortBy) {
                case IOU -> enrichedValid.stream()
                        .mapToDouble(e -> sign * e.iou)
                        .max()
                        .orElse(Double.MAX_VALUE);
                case CONFIDENCE -> enrichedValid.stream()
                        .filter(e -> e.detection != null)
                        .mapToDouble(e -> sign * e.detection.confidence)
                        .max()
                        .orElse(Double.MAX_VALUE);
                default -> throw new IllegalArgumentException("Unknown sorting mode " + this.filter.sorting.sortBy);
            };
        }

        /**
         * Resolves all columns from the factory on first use.
         *
         * @throws IllegalStateException if called later with a different factory instance
         */
        private void initializeColumnFactoryIfAbsent(ColumnFactory columnFactory) {
            if (this.columnFactory == null) {
                this.columnFactory = columnFactory;
                String targetColName = this.mlTask.targetVariable;
                this.targetCol = columnFactory.column(targetColName);
                this.detectionCol = columnFactory.column(DeephubImagesDataService.DETECTION_COLUMN_NAME);
                this.pairingCol = columnFactory.column(DeephubImagesDataService.PAIRING_COLUMN_NAME);
                this.sortColumn = columnFactory.column(SORT_COLUMN);
                this.filterColumn = columnFactory.column(FILTER_COLUMN);
                this.enrichedValidCol = columnFactory.column(DeephubImagesDataService.ENRICHED_VALID_COLUMN_NAME);
                this.enrichedFilteredCol = columnFactory.column(DeephubImagesDataService.ENRICHED_FILTERED_COLUMN_NAME);
            } else if (columnFactory != this.columnFactory) {
                throw new IllegalStateException("Different column factory already initialized");
            }
        }

        /**
         * Writes the derived columns of one row: the sort key ({@code Double.MAX_VALUE} unless a
         * sorting is set and the row has matching items), the filter flag (1 when at least one
         * item matches the filter, -1 otherwise) and the JSON of matching and rejected items.
         */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) throws IOException {
            this.initializeColumnFactoryIfAbsent(columnFactory);
            List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> enriched = this.buildEnriched(row, this.pairingCol, this.detectionCol, this.targetCol);
            // partitioningBy always provides both the true and the false keys.
            Map<Boolean, List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem>> partitioned = enriched.stream().collect(Collectors.partitioningBy(this::filterEnrichedItem));
            List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> enrichedValid = partitioned.get(true);
            List<DeepHubColumnFormat.EnrichedObjectDetectionPairedItem> enrichedFiltered = partitioned.get(false);
            boolean validRow = !enrichedValid.isEmpty();
            double sortValue = Double.MAX_VALUE;
            if (this.actualSorting() && validRow) {
                sortValue = this.getBestSortValue(enrichedValid);
            }
            row.put(this.sortColumn, sortValue);
            row.put(this.filterColumn, validRow ? 1 : -1);
            row.put(this.enrichedValidCol, JSON.json((Object)enrichedValid));
            row.put(this.enrichedFilteredCol, JSON.json((Object)enrichedFiltered));
        }
    }

    /**
     * Filtering and sorting criteria for object-detection predicted data.
     */
    @UIModel
    public static class ObjectDetectionPredictedFilter
    extends ComputerVisionPredictedFilter {
        // A pairing is kept only when its IoU is strictly greater than this value.
        float minIOU = 0.0f;
        // A detection is kept only when its confidence is strictly greater than this value.
        float minConfidence = 0.0f;
        // Constraint on the ground truth of an item; null or type ANY means no constraint.
        @Nullable
        Category groundTruth;
        // Constraint on the detection of an item; null or type ANY means no constraint.
        @Nullable
        Category detection;
        // Requested sorting; the computer treats null as "no sorting".
        Sorting sorting;
    }

    /**
     * Derived-columns computer for image-classification prediction reports.
     *
     * <p>For each row it writes a numeric filter flag (1 when the row matches the filter, -1
     * otherwise) and a numeric sort key based on the probability of the predicted class.
     */
    static class ImageClassificationDerivedColumnsComputer
    extends PredictedDataService.BaseDerivedColumnsComputer
    implements DeepHubPredictionReportDerivedColumnsComputer {
        static final String SORT_COLUMN = "__dku__deephub_sorting_column";
        static final String FILTER_COLUMN = "__dku__deephub_filter_column";
        private final ImageClassificationPredictedFilter filter;
        private final PredictionMLTask.DeepHubPredictionMLTask mlTask;
        private ColumnFactory columnFactory;
        private Column targetCol;
        private Column predictionCol;
        private Column filterColumn;
        private Column sortColumn;

        /**
         * @param filter filtering/sorting criteria; may be {@code null}
         * @param mlTask head ML task; must be a {@link PredictionMLTask.DeepHubPredictionMLTask}
         */
        ImageClassificationDerivedColumnsComputer(ImageClassificationPredictedFilter filter, MLTask mlTask) {
            this.filter = filter;
            this.mlTask = (PredictionMLTask.DeepHubPredictionMLTask)mlTask;
        }

        /**
         * Resolves all columns from the factory on first use.
         *
         * @throws IllegalStateException if called later with a different factory instance
         */
        private void initializeColumnFactoryIfAbsent(ColumnFactory columnFactory) {
            if (this.columnFactory != null) {
                if (this.columnFactory != columnFactory) {
                    throw new IllegalStateException("Different column factory already initialized");
                }
                return;
            }
            this.columnFactory = columnFactory;
            this.targetCol = columnFactory.column(this.mlTask.targetVariable);
            this.predictionCol = columnFactory.column(DeephubImagesDataService.DETECTION_COLUMN_NAME);
            this.filterColumn = columnFactory.column(FILTER_COLUMN);
            this.sortColumn = columnFactory.column(SORT_COLUMN);
        }

        /**
         * @return {@code null} when there is no filter; otherwise a numerical facet on
         *         {@link #FILTER_COLUMN} with a minimum of 0
         */
        @Override
        public FilterRequest getFilterRequest() {
            if (this.filter == null) {
                return null;
            }
            FilterRequest.FilterElement flagFacet = new FilterRequest.FilterElement();
            flagFacet.active = true;
            flagFacet.column = FILTER_COLUMN;
            flagFacet.type = FilterRequest.FilterType.NUMERICAL_FACET;
            flagFacet.minValue = 0.0;

            FilterRequest request = new FilterRequest();
            request.elements.add(flagFacet);
            return request;
        }

        /**
         * @return a single ascending sort on {@link #SORT_COLUMN}; the requested direction is
         *         encoded in the sign of the key written by {@link #compute}
         */
        @Override
        public List<SerializedShakerScript.TableSorting> getSortingRequest() {
            SerializedShakerScript.TableSorting bySortKey = new SerializedShakerScript.TableSorting();
            bySortKey.column = SORT_COLUMN;
            bySortKey.ascending = true;

            List<SerializedShakerScript.TableSorting> sortings = new ArrayList<>();
            sortings.add(bySortKey);
            return sortings;
        }

        /**
         * Writes the filter flag and the sort key of one row.
         *
         * <p>The sort key is {@code Double.MAX_VALUE} unless a filter is set and the row matches
         * it; in that case it is the value of column {@code proba_<predicted class>}, negated
         * when the filter asks for a descending order.
         */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) throws IOException {
            this.initializeColumnFactoryIfAbsent(columnFactory);
            String groundTruthValue = row.get(this.targetCol);
            String predictedValue = row.get(this.predictionCol);

            boolean keepRow = this.isValidRow(groundTruthValue, predictedValue);
            row.put(this.filterColumn, keepRow ? 1 : -1);

            double sortKey;
            if (this.filter == null || !keepRow) {
                sortKey = Double.MAX_VALUE;
            } else {
                double probability = Double.parseDouble(row.get(columnFactory.column("proba_" + predictedValue)));
                sortKey = this.filter.ascending ? probability : -probability;
            }
            row.put(this.sortColumn, sortKey);
        }

        /**
         * Returns whether a row matches the filter: it must have a prediction, and each non-null
         * criterion (detection, ground truth) must equal the corresponding row value.
         */
        private boolean isValidRow(String actualVal, String predictedVal) {
            if (predictedVal == null) {
                return false;
            }
            if (this.filter == null) {
                return true;
            }
            return (this.filter.detection == null || Strings.equals(this.filter.detection, predictedVal))
                    && (this.filter.groundTruth == null || Strings.equals(this.filter.groundTruth, actualVal));
        }
    }

    /**
     * Filtering and sorting criteria for image-classification predicted data.
     */
    @UIModel
    public static class ImageClassificationPredictedFilter
    extends ComputerVisionPredictedFilter {
        // When non-null, a row is kept only if its target column value equals this category.
        @Nullable
        String groundTruth;
        // When non-null, a row is kept only if its predicted value equals this category.
        @Nullable
        String detection;
        // When false, the per-row sort key is negated (the table sort itself is always ascending).
        boolean ascending = false;
    }

    /**
     * Common base type of the predicted-data filters. The {@code @PolyJSON} mappings associate
     * the type name "DEEP_HUB_IMAGE_OBJECT_DETECTION" with {@link ObjectDetectionPredictedFilter}
     * and "DEEP_HUB_IMAGE_CLASSIFICATION" with {@link ImageClassificationPredictedFilter}.
     */
    @PolyJSON(value={@Mapping(value=ObjectDetectionPredictedFilter.class, type="DEEP_HUB_IMAGE_OBJECT_DETECTION"), @Mapping(value=ImageClassificationPredictedFilter.class, type="DEEP_HUB_IMAGE_CLASSIFICATION")})
    public static abstract class ComputerVisionPredictedFilter {
    }

    /**
     * A derived-columns computer that also supplies the filter and sorting requests to apply on
     * the columns it computes.
     */
    static interface DeepHubPredictionReportDerivedColumnsComputer
    extends PredictedDataService.DerivedColumnsComputer {
        /** Returns the filter to apply on the computed table; implementations in this file return null for "no filtering". */
        public FilterRequest getFilterRequest();

        /** Returns the sorting to apply on the computed table; the object-detection implementation returns null for "no sorting". */
        public List<SerializedShakerScript.TableSorting> getSortingRequest();
    }

    /**
     * Sorting requested on object-detection predicted data.
     */
    static class Sorting {
        // Per-item value the row sort key is derived from.
        SortBy sortBy = SortBy.IOU;
        // When false, item values are negated before the row sort key is computed.
        boolean ascending = false;

        Sorting() {
        }
    }

    /**
     * Item value used to sort object-detection rows: the IoU of the item, or the confidence of
     * its detection (items without a detection are ignored for CONFIDENCE).
     */
    static enum SortBy {
        IOU,
        CONFIDENCE;

    }

    /**
     * Category constraint applied to the ground truth or the detection of an item.
     */
    static class Category {
        // Kind of constraint; see CategoryType.
        CategoryType type;
        // Category compared with the item's category when type is ONE.
        String value;

        Category() {
        }
    }

    /**
     * Kind of category constraint: ANY imposes no constraint, NONE requires the item to have no
     * ground truth (or detection), ONE requires it to have one whose category equals the
     * constraint value.
     */
    static enum CategoryType {
        ANY,
        NONE,
        ONE;

    }
}

