/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.dataflow.exec.fuzzyjoin.builtinengine.verifier;

import com.dataiku.dip.dataflow.exec.fuzzyjoin.FuzzyJoinRecipePayloadParams;
import com.dataiku.dip.dataflow.exec.fuzzyjoin.builtinengine.verifier.Matcher;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * {@link Matcher} implementation whose distance is {@code 1 - cosine similarity}
 * between the character-count vectors of the two compared strings.
 *
 * <p>Each string is turned into a map from {@code char} (UTF-16 code unit) to its
 * number of occurrences; the similarity is the dot product of the two count vectors
 * divided by the product of their Euclidean norms.
 */
public class CosineMatcher
extends Matcher {
    /**
     * Creates a cosine matcher; all arguments are forwarded unchanged to the
     * {@link Matcher} constructor.
     *
     * @param threshold  passed to the superclass
     * @param relativeTo passed to the superclass
     * @param debugMode  passed to the superclass
     */
    public CosineMatcher(double threshold, Integer relativeTo, boolean debugMode) {
        super(threshold, relativeTo, debugMode);
    }

    /**
     * Accepts a pair of inputs only when both strings are non-null.
     *
     * @param query     the query string
     * @param candidate the candidate string
     * @return {@code true} if neither argument is {@code null}
     */
    @Override
    protected boolean isValidInputs(String query, String candidate) {
        return query != null && candidate != null;
    }

    /**
     * Computes the cosine distance between the character-count vectors of the inputs.
     *
     * <p>If either string is empty its vector has zero norm, the similarity is taken
     * as {@code 0.0} and the returned distance is therefore {@code 1.0}.
     *
     * @param query     the query string; must not be {@code null}
     * @param candidate the candidate string; must not be {@code null}
     * @return {@code 1.0 - cosineSimilarity(query, candidate)}
     */
    @Override
    public double computeDistance(String query, String candidate) {
        Map<Character, Integer> queryVector = this.countCharacters(query);
        Map<Character, Integer> candidateVector = this.countCharacters(candidate);
        return 1.0 - this.cosineSimilarity(queryVector, candidateVector);
    }

    /**
     * @return {@link FuzzyJoinRecipePayloadParams.DistanceType#COSINE}
     */
    @Override
    protected FuzzyJoinRecipePayloadParams.DistanceType getDistanceType() {
        return FuzzyJoinRecipePayloadParams.DistanceType.COSINE;
    }

    /**
     * Builds the character-count vector of a string.
     *
     * @param s the string to analyse
     * @return a map from each {@code char} occurring in {@code s} to its occurrence count
     */
    private Map<Character, Integer> countCharacters(String s) {
        HashMap<Character, Integer> counter = new HashMap<Character, Integer>();
        for (char c : s.toCharArray()) {
            // Single lookup: a missing key means this is the first occurrence.
            Integer previous = counter.get(c);
            counter.put(c, previous == null ? 1 : previous + 1);
        }
        return counter;
    }

    /**
     * Computes the cosine similarity of two character-count vectors.
     *
     * @param vectorA first count vector
     * @param vectorB second count vector
     * @return the dot product divided by the product of the norms, or {@code 0.0}
     *         when either vector has zero norm
     */
    private double cosineSimilarity(Map<Character, Integer> vectorA, Map<Character, Integer> vectorB) {
        Set<Character> intersection = this.getIntersection(vectorA, vectorB);
        long dotProduct = this.dot(vectorA, vectorB, intersection);
        return this.normaliseProduct(vectorA, vectorB, dotProduct);
    }

    /**
     * Returns the characters present in both vectors. Only these contribute to the
     * dot product, since a missing character has an implicit count of zero.
     *
     * @param vectorA first count vector
     * @param vectorB second count vector
     * @return a new set holding the keys common to both maps
     */
    private Set<Character> getIntersection(Map<Character, Integer> vectorA, Map<Character, Integer> vectorB) {
        HashSet<Character> intersection = new HashSet<Character>(vectorA.keySet());
        intersection.retainAll(vectorB.keySet());
        return intersection;
    }

    /**
     * Computes the dot product of two count vectors over their shared keys.
     *
     * @param leftVector   first count vector
     * @param rightVector  second count vector
     * @param intersection keys present in both vectors
     * @return the sum over {@code intersection} of the products of the two counts
     */
    private long dot(Map<Character, Integer> leftVector, Map<Character, Integer> rightVector, Set<Character> intersection) {
        long dotProduct = 0L;
        for (Character key : intersection) {
            // Widen to long BEFORE multiplying: an int * int product silently wraps
            // once both counts exceed ~46340, which would corrupt the similarity.
            long leftCount = leftVector.get(key);
            long rightCount = rightVector.get(key);
            dotProduct += leftCount * rightCount;
        }
        return dotProduct;
    }

    /**
     * Divides a dot product by the product of the Euclidean norms of both vectors.
     *
     * @param vectorA    first count vector
     * @param vectorB    second count vector
     * @param dotProduct dot product of the two vectors
     * @return {@code dotProduct / (|A| * |B|)}, or {@code 0.0} if either norm is not
     *         strictly positive (e.g. an empty vector), avoiding a division by zero
     */
    private double normaliseProduct(Map<Character, Integer> vectorA, Map<Character, Integer> vectorB, long dotProduct) {
        double normA = this.computeVectorNorm(vectorA.values());
        double normB = this.computeVectorNorm(vectorB.values());
        if (normA <= 0.0 || normB <= 0.0) {
            return 0.0;
        }
        return (double)dotProduct / (normA * normB);
    }

    /**
     * Computes the Euclidean (L2) norm of a collection of counts.
     *
     * @param vector the counts
     * @return the square root of the sum of the squared counts; {@code 0.0} when empty
     */
    private double computeVectorNorm(Collection<Integer> vector) {
        double sumOfSquares = 0.0;
        for (int count : vector) {
            sumOfSquares += Math.pow(count, 2.0);
        }
        return Math.sqrt(sumOfSquares);
    }
}

