/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.labeling.objectdetection;

import com.dataiku.dip.labeling.BaseLabelingAnswer;
import com.dataiku.dip.labeling.objectdetection.BoundingBox;
import com.dataiku.dip.labeling.region.LabelingRegion;
import com.dataiku.dip.labeling.region.LabelingRegionDispatcher;
import com.dataiku.dip.labeling.region.RegionElement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Groups the bounding boxes of several labeling answers into regions.
 *
 * <p>Grouping is greedy. Every pair of boxes taken from two different answers whose IoU is
 * strictly positive is visited in order of decreasing IoU:
 * <ul>
 *   <li>if neither box is assigned yet, the pair opens a new region;</li>
 *   <li>if exactly one box is assigned, the other one joins that box's region, unless the region
 *       already reports an element from the same answer (via {@code hasElementFrom});</li>
 *   <li>if both boxes are assigned, the pair is ignored.</li>
 * </ul>
 * Boxes that were never grouped end up in their own single-element region.
 *
 * <p>{@link #computeConflict} then flags a region when it has fewer elements than there are
 * answers, or when two of its boxes differ in category or overlap less than the configured
 * IoU threshold.
 */
public class SimpleObjectDetectionRegionDispatcher
extends LabelingRegionDispatcher<BoundingBox> {
    /** Minimum IoU two boxes of the same region must reach to not be reported as conflicting. */
    private final double iouThreshold;

    /**
     * @param iouThreshold minimum IoU below which two boxes of the same region are reported as
     *     conflicting by {@link #computeConflict}
     */
    public SimpleObjectDetectionRegionDispatcher(double iouThreshold) {
        this.iouThreshold = iouThreshold;
    }

    /**
     * Greedily groups the boxes of all answers into regions (see the class documentation).
     *
     * @param answers the labeling answers; the position of an answer in this list is the index
     *     handed to {@code regionElementsFromAnswers}
     * @return the built regions: first the regions created from overlapping pairs, in creation
     *     order, then one single-element region per box that was never grouped, in answer order
     * @throws IllegalStateException if two region elements share the same id
     */
    @Override
    protected List<LabelingRegion<BoundingBox>> buildRegions(List<? extends BaseLabelingAnswer> answers) {
        // One list of boxes per answer. The answer index is passed along, presumably so each
        // element can be traced back to its answer (answerIdx) — TODO confirm in the base class.
        List<List<RegionElement<BoundingBox>>> regionElementsLists = IntStream.range(0, answers.size())
                .mapToObj(i -> this.regionElementsFromAnswers(answers.get(i), i))
                .collect(Collectors.toList());

        // Boxes not yet assigned to any region, keyed by element id. A LinkedHashMap keeps answer
        // order so that the left-over singleton regions are emitted in a deterministic order.
        Map<String, RegionElement<BoundingBox>> boxesLeftToDispatch = regionElementsLists.stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toMap(
                        element -> element.id,
                        element -> element,
                        (first, second) -> {
                            throw new IllegalStateException("Duplicate region element id: " + first.id);
                        },
                        LinkedHashMap::new));
        Map<String, Integer> boxIdToRegionIdx = new HashMap<>();

        List<BoxesPairWithIOU> boxesPairsWithIOU = this.computeIoUs(regionElementsLists);
        // Highest-overlap pairs first, so the best matches claim boxes before weaker ones.
        boxesPairsWithIOU.sort(Collections.reverseOrder());

        List<LabelingRegion<BoundingBox>> regions = new ArrayList<>();
        for (BoxesPairWithIOU boxesPair : boxesPairsWithIOU) {
            boolean shouldDispatchBox1 = boxesLeftToDispatch.containsKey(boxesPair.box1.id);
            boolean shouldDispatchBox2 = boxesLeftToDispatch.containsKey(boxesPair.box2.id);

            if (shouldDispatchBox1 && shouldDispatchBox2) {
                // Neither box is assigned yet: the pair opens a new region.
                regions.add(new LabelingRegion<BoundingBox>(boxesPair.box1, boxesPair.box2));
                int newRegionIndex = regions.size() - 1;
                boxIdToRegionIdx.put(boxesPair.box1.id, newRegionIndex);
                boxIdToRegionIdx.put(boxesPair.box2.id, newRegionIndex);
                boxesLeftToDispatch.remove(boxesPair.box1.id);
                boxesLeftToDispatch.remove(boxesPair.box2.id);
                continue;
            }
            if (!shouldDispatchBox1 && !shouldDispatchBox2) {
                // Both boxes already belong to a region: nothing to do for this pair.
                continue;
            }

            // Exactly one box is assigned. Every removed box was recorded in boxIdToRegionIdx
            // above, so the lookup for the assigned box always succeeds.
            RegionElement<BoundingBox> assignedBox = shouldDispatchBox1 ? boxesPair.box2 : boxesPair.box1;
            RegionElement<BoundingBox> boxToDispatch = shouldDispatchBox1 ? boxesPair.box1 : boxesPair.box2;
            int regionIndex = boxIdToRegionIdx.get(assignedBox.id);
            LabelingRegion<BoundingBox> region = regions.get(regionIndex);
            // Do not add a second box coming from an answer the region already reports.
            if (region.hasElementFrom(boxToDispatch.answerIdx)) {
                continue;
            }
            region.elements.add(boxToDispatch);
            boxesLeftToDispatch.remove(boxToDispatch.id);
            boxIdToRegionIdx.put(boxToDispatch.id, regionIndex);
        }

        // Every box that was never grouped becomes its own single-element region.
        for (RegionElement<BoundingBox> leftOverBox : boxesLeftToDispatch.values()) {
            regions.add(new LabelingRegion<BoundingBox>(leftOverBox));
        }
        return regions;
    }

    /**
     * Tells whether the annotators disagree on a region.
     *
     * @param region the region to check
     * @param totalNbAnswers the number of answers the region is compared against
     * @return {@code true} if the region has fewer than {@code totalNbAnswers} elements, or if
     *     any two of its boxes have different categories or an IoU below {@code iouThreshold}
     */
    @Override
    public boolean computeConflict(LabelingRegion<BoundingBox> region, int totalNbAnswers) {
        // Some answer did not contribute a box to this region.
        if (region.elements.size() < totalNbAnswers) {
            return true;
        }
        for (int i = 0; i < region.elements.size() - 1; ++i) {
            for (int j = i + 1; j < region.elements.size(); ++j) {
                if (this.isConflicting(region.elements.get(i), region.elements.get(j))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Lists every overlapping pair of boxes taken from two different answers.
     *
     * @param regionElementsLists the boxes, one list per answer
     * @return a mutable list of pairs whose IoU is strictly positive, unsorted
     */
    private List<BoxesPairWithIOU> computeIoUs(List<List<RegionElement<BoundingBox>>> regionElementsLists) {
        List<BoxesPairWithIOU> pairs = new ArrayList<>();
        for (int i = 0; i < regionElementsLists.size(); ++i) {
            List<RegionElement<BoundingBox>> regionElements1 = regionElementsLists.get(i);
            // j > i: each pair of answers is visited once, and never an answer with itself.
            for (int j = i + 1; j < regionElementsLists.size(); ++j) {
                List<RegionElement<BoundingBox>> regionElements2 = regionElementsLists.get(j);
                for (RegionElement<BoundingBox> regionElement1 : regionElements1) {
                    for (RegionElement<BoundingBox> regionElement2 : regionElements2) {
                        double iou = this.computeIoU(regionElement1, regionElement2);
                        // Keep only overlapping boxes; this comparison also discards NaN.
                        if (iou > 0.0) {
                            pairs.add(new BoxesPairWithIOU(regionElement1, regionElement2, iou));
                        }
                    }
                }
            }
        }
        return pairs;
    }

    /** Returns the IoU between the bounding boxes of the two elements. */
    private double computeIoU(RegionElement<BoundingBox> regionElement1, RegionElement<BoundingBox> regionElement2) {
        return ((BoundingBox)regionElement1.annotation).bbox.computeIoUWith(((BoundingBox)regionElement2.annotation).bbox);
    }

    /** Two boxes conflict when their categories differ or their IoU is below the threshold. */
    private boolean isConflicting(RegionElement<BoundingBox> box1, RegionElement<BoundingBox> box2) {
        return !((BoundingBox)box1.annotation).category.equals(((BoundingBox)box2.annotation).category)
                || this.computeIoU(box1, box2) < this.iouThreshold;
    }

    /** A pair of boxes from two different answers, ordered by their IoU. */
    private static final class BoxesPairWithIOU
    implements Comparable<BoxesPairWithIOU> {
        final RegionElement<BoundingBox> box1;
        final RegionElement<BoundingBox> box2;
        final double iou;

        BoxesPairWithIOU(RegionElement<BoundingBox> box1, RegionElement<BoundingBox> box2, double iou) {
            this.box1 = box1;
            this.box2 = box2;
            this.iou = iou;
        }

        /** Natural order is ascending IoU. */
        @Override
        public int compareTo(BoxesPairWithIOU pair) {
            return Double.compare(this.iou, pair.iou);
        }
    }
}

