/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dss.shadelib.org.apache.lucene.util;

import com.dataiku.dss.shadelib.org.apache.lucene.internal.vectorization.VectorUtilSupport;
import com.dataiku.dss.shadelib.org.apache.lucene.internal.vectorization.VectorizationProvider;
import com.dataiku.dss.shadelib.org.apache.lucene.util.BitUtil;
import com.dataiku.dss.shadelib.org.apache.lucene.util.Constants;

/**
 * Static helpers for comparing, normalizing and accumulating dense vectors.
 *
 * <p>Most similarity and distance computations are delegated to a {@link VectorUtilSupport}
 * obtained from {@link VectorizationProvider}; the methods here validate arguments around those
 * calls. This class cannot be instantiated.
 */
public final class VectorUtil {
    /** Tolerance used when deciding whether a vector's squared L2 norm is already 1. */
    private static final float EPSILON = 1.0E-4f;

    /** Backend that performs the actual vector arithmetic. */
    private static final VectorUtilSupport IMPL =
            VectorizationProvider.getInstance().getVectorUtilSupport();

    /**
     * When {@code true}, {@link #xorBitCount} reads 4 bytes per step instead of 8. It is set when
     * {@code Constants.OS_ARCH} equals {@code "aarch64"}.
     */
    static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64");

    private VectorUtil() {
    }

    /**
     * Throws if two vectors do not have the same number of dimensions.
     *
     * @throws IllegalArgumentException if {@code aLength != bLength}
     */
    private static void checkSameDimensions(int aLength, int bLength) {
        if (aLength != bLength) {
            throw new IllegalArgumentException("vector dimensions differ: " + aLength + "!=" + bLength);
        }
    }

    /**
     * Returns the dot product of two float vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static float dotProduct(float[] a, float[] b) {
        checkSameDimensions(a.length, b.length);
        float r = IMPL.dotProduct(a, b);
        assert Float.isFinite(r);
        return r;
    }

    /**
     * Returns the cosine similarity of two float vectors, as computed by the backend.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static float cosine(float[] a, float[] b) {
        checkSameDimensions(a.length, b.length);
        float r = IMPL.cosine(a, b);
        assert Float.isFinite(r);
        return r;
    }

    /**
     * Returns the cosine similarity of two byte vectors, as computed by the backend.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static float cosine(byte[] a, byte[] b) {
        checkSameDimensions(a.length, b.length);
        return IMPL.cosine(a, b);
    }

    /**
     * Returns the square distance between two float vectors, as computed by the backend.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static float squareDistance(float[] a, float[] b) {
        checkSameDimensions(a.length, b.length);
        float r = IMPL.squareDistance(a, b);
        assert Float.isFinite(r);
        return r;
    }

    /**
     * Returns the square distance between two byte vectors, as computed by the backend.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static int squareDistance(byte[] a, byte[] b) {
        checkSameDimensions(a.length, b.length);
        return IMPL.squareDistance(a, b);
    }

    /**
     * Scales {@code v} in place to unit L2 length and returns it.
     *
     * @throws IllegalArgumentException if {@code v} has zero length (all-zero components)
     */
    public static float[] l2normalize(float[] v) {
        l2normalize(v, true);
        return v;
    }

    /**
     * Returns whether {@code v}'s squared L2 norm (its dot product with itself) is within
     * {@link #EPSILON} of 1.
     */
    public static boolean isUnitVector(float[] v) {
        double squaredNorm = IMPL.dotProduct(v, v);
        return Math.abs(squaredNorm - 1.0) <= EPSILON;
    }

    /**
     * Scales {@code v} in place to unit L2 length and returns it.
     *
     * <p>Vectors whose squared norm is already within {@link #EPSILON} of 1 are returned unchanged.
     *
     * @param v the vector to normalize; modified in place
     * @param throwOnZero whether an all-zero vector is an error; when {@code false} it is returned
     *     unchanged
     * @return {@code v}
     * @throws IllegalArgumentException if {@code v}'s norm is zero and {@code throwOnZero} is set
     */
    public static float[] l2normalize(float[] v, boolean throwOnZero) {
        double squaredNorm = IMPL.dotProduct(v, v);
        if (squaredNorm == 0.0) {
            if (throwOnZero) {
                throw new IllegalArgumentException("Cannot normalize a zero-length vector");
            }
            return v;
        }
        if (Math.abs(squaredNorm - 1.0) <= EPSILON) {
            // Already (nearly) unit length; skip the division pass.
            return v;
        }
        float norm = (float) Math.sqrt(squaredNorm);
        for (int i = 0; i < v.length; i++) {
            v[i] /= norm;
        }
        return v;
    }

    /**
     * Adds {@code v} element-wise into {@code u} ({@code u[i] += v[i]}).
     *
     * <p>Iterates over {@code u.length} elements only: {@code v} must have at least that many
     * elements (otherwise an {@link ArrayIndexOutOfBoundsException} is thrown part-way through),
     * and any extra elements of {@code v} are ignored.
     */
    public static void add(float[] u, float[] v) {
        for (int i = 0; i < u.length; i++) {
            u[i] += v[i];
        }
    }

    /**
     * Returns the dot product of two byte vectors.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static int dotProduct(byte[] a, byte[] b) {
        checkSameDimensions(a.length, b.length);
        return IMPL.dotProduct(a, b);
    }

    /**
     * Returns the backend's int4 dot product of two unpacked vectors (one value per byte;
     * presumably each value fits in 4 bits — NOTE(review): confirm against the backend).
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static int int4DotProduct(byte[] a, byte[] b) {
        checkSameDimensions(a.length, b.length);
        return IMPL.int4DotProduct(a, false, b, false);
    }

    /**
     * Returns the backend's int4 dot product of an unpacked vector and a packed one.
     *
     * @param unpacked one value per byte
     * @param packed two values per byte; must hold exactly {@code ceil(unpacked.length / 2)} bytes
     * @throws IllegalArgumentException if {@code packed} has the wrong length
     */
    public static int int4DotProductPacked(byte[] unpacked, byte[] packed) {
        if (packed.length != ((unpacked.length + 1) >> 1)) {
            throw new IllegalArgumentException("vector dimensions differ: " + unpacked.length + "!= 2 * " + packed.length);
        }
        return IMPL.int4DotProduct(unpacked, false, packed, true);
    }

    /**
     * Returns the number of bit positions at which {@code a} and {@code b} differ (the Hamming
     * distance over all bits of both arrays).
     *
     * @throws IllegalArgumentException if the arrays have different lengths
     */
    public static int xorBitCount(byte[] a, byte[] b) {
        checkSameDimensions(a.length, b.length);
        return XOR_BIT_COUNT_STRIDE_AS_INT ? xorBitCountInt(a, b) : xorBitCountLong(a, b);
    }

    /**
     * Hamming distance reading 4 bytes at a time, with a byte-wise tail. Assumes equal lengths.
     * Byte order of the native read does not affect the popcount of the XOR.
     */
    static int xorBitCountInt(byte[] a, byte[] b) {
        int distance = 0;
        int upperBound = a.length & ~3;
        int i = 0;
        for (; i < upperBound; i += Integer.BYTES) {
            // VarHandle.get is signature-polymorphic: the cast fixes its return type to int.
            int x = (int) BitUtil.VH_NATIVE_INT.get(a, i);
            int y = (int) BitUtil.VH_NATIVE_INT.get(b, i);
            distance += Integer.bitCount(x ^ y);
        }
        for (; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    /**
     * Hamming distance reading 8 bytes at a time, with a byte-wise tail. Assumes equal lengths.
     * Byte order of the native read does not affect the popcount of the XOR.
     */
    static int xorBitCountLong(byte[] a, byte[] b) {
        int distance = 0;
        int upperBound = a.length & ~7;
        int i = 0;
        for (; i < upperBound; i += Long.BYTES) {
            // VarHandle.get is signature-polymorphic: the cast fixes its return type to long.
            long x = (long) BitUtil.VH_NATIVE_LONG.get(a, i);
            long y = (long) BitUtil.VH_NATIVE_LONG.get(b, i);
            distance += Long.bitCount(x ^ y);
        }
        for (; i < a.length; i++) {
            distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF);
        }
        return distance;
    }

    /**
     * Maps the dot product of two byte vectors to {@code 0.5 + dot / (length * 2^15)}.
     *
     * <p>Each component product is at most {@code 128 * 128 = 2^14} in magnitude, so the result lies
     * in {@code [0, 1]}. For empty vectors the result is {@code NaN} (0 / 0).
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    public static float dotProductScore(byte[] a, byte[] b) {
        // Computed in float: the int product a.length * 2^15 overflows once a.length >= 2^16.
        float denom = (float) a.length * (1 << 15);
        return 0.5f + (float) dotProduct(a, b) / denom;
    }

    /**
     * Maps an unbounded dot product to a positive, order-preserving score: negative inputs map to
     * {@code 1 / (1 - x)} in {@code (0, 1)}, non-negative inputs to {@code x + 1}.
     */
    public static float scaleMaxInnerProductScore(float vectorDotProductSimilarity) {
        if (vectorDotProductSimilarity < 0.0f) {
            return 1.0f / (1.0f - vectorDotProductSimilarity);
        }
        return vectorDotProductSimilarity + 1.0f;
    }

    /**
     * Returns {@code v} unchanged if every component is finite.
     *
     * @throws IllegalArgumentException naming the first NaN or infinite component
     */
    public static float[] checkFinite(float[] v) {
        for (int i = 0; i < v.length; i++) {
            if (!Float.isFinite(v[i])) {
                throw new IllegalArgumentException("non-finite value at vector[" + i + "]=" + v[i]);
            }
        }
        return v;
    }
}

