/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dss.shadelib.org.apache.lucene.internal.vectorization;

import com.dataiku.dss.shadelib.org.apache.lucene.index.VectorSimilarityFunction;
import com.dataiku.dss.shadelib.org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport;
import com.dataiku.dss.shadelib.org.apache.lucene.store.FilterIndexInput;
import com.dataiku.dss.shadelib.org.apache.lucene.store.IndexInput;
import com.dataiku.dss.shadelib.org.apache.lucene.store.MemorySegmentAccessInput;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomVectorScorer;
import com.dataiku.dss.shadelib.org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;

/**
 * Base class for {@link RandomVectorScorerSupplier}s that score byte vectors read through a
 * {@link MemorySegmentAccessInput}.
 *
 * <p>Vector {@code ord} is read from byte offset {@code ord * vectorByteSize} of the input. When
 * the input cannot hand out a direct slice for a vector ({@code segmentSliceOrNull} returns
 * {@code null}), the bytes are copied into one of two lazily allocated scratch arrays instead.
 *
 * <p>The scratch arrays are mutable per-instance state; {@link #copy()} on each concrete supplier
 * builds a new instance over a cloned input, which starts with its own (unallocated) scratch
 * arrays. NOTE(review): presumably an instance is therefore meant to be used by one thread at a
 * time - confirm against callers.
 */
public abstract class Lucene99MemorySegmentByteVectorScorerSupplier
implements RandomVectorScorerSupplier {
    /** Size in bytes of one stored vector, taken from {@code values.getVectorByteLength()}. */
    final int vectorByteSize;
    /** Number of vectors, taken from {@code values.size()}; valid ordinals are {@code [0, maxOrd)}. */
    final int maxOrd;
    /** Input the vector bytes are read from. */
    final MemorySegmentAccessInput input;
    /** Vector values handed to every scorer created by this supplier. */
    final RandomAccessVectorValues values;
    /** Fallback buffer for {@link #getFirstSegment(int)}; allocated on first use. */
    byte[] scratch1;
    /** Fallback buffer for {@link #getSecondSegment(int)}; allocated on first use. */
    byte[] scratch2;

    /**
     * Creates a supplier for the given similarity function, if the input supports memory-segment
     * access.
     *
     * @param type similarity function to score with
     * @param input input holding the vector data; it is first passed through
     *     {@code FilterIndexInput.unwrapOnlyTest}
     * @param values vector values describing vector count and byte length
     * @return a supplier, or {@link Optional#empty()} when the unwrapped input is not a
     *     {@link MemorySegmentAccessInput}
     * @throws IllegalArgumentException if the input is shorter than the expected vector data, or
     *     if {@code type} is not one of the four handled similarity functions
     */
    static Optional<RandomVectorScorerSupplier> create(VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) {
        IndexInput unwrapped = FilterIndexInput.unwrapOnlyTest(input);
        if (!(unwrapped instanceof MemorySegmentAccessInput msInput)) {
            return Optional.empty();
        }
        checkInvariants(values.size(), values.getVectorByteLength(), unwrapped);
        if (type == VectorSimilarityFunction.COSINE) {
            return Optional.of(new CosineSupplier(msInput, values));
        }
        if (type == VectorSimilarityFunction.DOT_PRODUCT) {
            return Optional.of(new DotProductSupplier(msInput, values));
        }
        if (type == VectorSimilarityFunction.EUCLIDEAN) {
            return Optional.of(new EuclideanSupplier(msInput, values));
        }
        if (type == VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT) {
            return Optional.of(new MaxInnerProductSupplier(msInput, values));
        }
        throw new IllegalArgumentException("unknown type: " + type);
    }

    /**
     * Stores the input and values and caches the vector byte length and vector count.
     *
     * @param input input to read vector bytes from
     * @param values vector values providing {@code getVectorByteLength()} and {@code size()}
     */
    Lucene99MemorySegmentByteVectorScorerSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
        this.input = input;
        this.values = values;
        this.vectorByteSize = values.getVectorByteLength();
        this.maxOrd = values.size();
    }

    /**
     * Verifies that the input is long enough to hold {@code maxOrd} vectors of
     * {@code vectorByteLength} bytes each.
     *
     * @throws IllegalArgumentException if {@code input.length()} is smaller than the product
     */
    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        // Multiply as longs so that a large vector count cannot overflow int.
        if (input.length() < (long)vectorByteLength * (long)maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    /**
     * Validates a vector ordinal.
     *
     * @throws IllegalArgumentException if {@code ord} is negative or not less than {@code maxOrd}
     */
    final void checkOrdinal(int ord) {
        if (ord < 0 || ord >= this.maxOrd) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    /**
     * Returns a segment holding the bytes of vector {@code ord}.
     *
     * <p>When the input cannot provide a slice, the returned segment wraps {@code scratch1}, so
     * its contents are overwritten by the next call that takes the same fallback path.
     *
     * @param ord vector ordinal; not range-checked here
     * @throws IOException if reading the bytes from the input fails
     */
    final MemorySegment getFirstSegment(int ord) throws IOException {
        long byteOffset = (long)ord * (long)this.vectorByteSize;
        MemorySegment seg = this.input.segmentSliceOrNull(byteOffset, this.vectorByteSize);
        if (seg == null) {
            if (this.scratch1 == null) {
                this.scratch1 = new byte[this.vectorByteSize];
            }
            this.input.readBytes(byteOffset, this.scratch1, 0, this.vectorByteSize);
            seg = MemorySegment.ofArray(this.scratch1);
        }
        return seg;
    }

    /**
     * Returns a segment holding the bytes of vector {@code ord}, using {@code scratch2} as the
     * fallback buffer so the result can be used together with one from
     * {@link #getFirstSegment(int)}.
     *
     * @param ord vector ordinal; not range-checked here
     * @throws IOException if reading the bytes from the input fails
     */
    final MemorySegment getSecondSegment(int ord) throws IOException {
        long byteOffset = (long)ord * (long)this.vectorByteSize;
        MemorySegment seg = this.input.segmentSliceOrNull(byteOffset, this.vectorByteSize);
        if (seg == null) {
            if (this.scratch2 == null) {
                this.scratch2 = new byte[this.vectorByteSize];
            }
            this.input.readBytes(byteOffset, this.scratch2, 0, this.vectorByteSize);
            seg = MemorySegment.ofArray(this.scratch2);
        }
        return seg;
    }

    /** Supplier whose scorers turn the raw dot product into a maximum-inner-product score. */
    static final class MaxInnerProductSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
            super(input, values);
        }

        /**
         * Returns a scorer comparing vector {@code ord} against other ordinals.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        @Override
        public RandomVectorScorer scorer(final int ord) {
            this.checkOrdinal(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(this.values){

                @Override
                public float score(int node) throws IOException {
                    // The helpers live on the enclosing supplier, not on this anonymous scorer,
                    // so they must be reached through the qualified outer instance.
                    MaxInnerProductSupplier.this.checkOrdinal(node);
                    float raw = PanamaVectorUtilSupport.dotProduct(MaxInnerProductSupplier.this.getFirstSegment(ord), MaxInnerProductSupplier.this.getSecondSegment(node));
                    // Negative products map into (0, 1); non-negative ones into [1, inf).
                    if (raw < 0.0f) {
                        return 1.0f / (1.0f + -1.0f * raw);
                    }
                    return raw + 1.0f;
                }
            };
        }

        /** Returns a new supplier over a clone of the input, sharing the same values. */
        @Override
        public MaxInnerProductSupplier copy() throws IOException {
            return new MaxInnerProductSupplier(this.input.clone(), this.values);
        }
    }

    /** Supplier whose scorers compute {@code 1 / (1 + squareDistance)}. */
    static final class EuclideanSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
            super(input, values);
        }

        /**
         * Returns a scorer comparing vector {@code ord} against other ordinals.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        @Override
        public RandomVectorScorer scorer(final int ord) {
            this.checkOrdinal(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(this.values){

                @Override
                public float score(int node) throws IOException {
                    // Qualified outer calls: the helpers are declared on the supplier.
                    EuclideanSupplier.this.checkOrdinal(node);
                    float raw = PanamaVectorUtilSupport.squareDistance(EuclideanSupplier.this.getFirstSegment(ord), EuclideanSupplier.this.getSecondSegment(node));
                    return 1.0f / (1.0f + raw);
                }
            };
        }

        /** Returns a new supplier over a clone of the input, sharing the same values. */
        @Override
        public EuclideanSupplier copy() throws IOException {
            return new EuclideanSupplier(this.input.clone(), this.values);
        }
    }

    /** Supplier whose scorers compute {@code 0.5 + dotProduct / (dimension * 32768)}. */
    static final class DotProductSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
            super(input, values);
        }

        /**
         * Returns a scorer comparing vector {@code ord} against other ordinals.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        @Override
        public RandomVectorScorer scorer(final int ord) {
            this.checkOrdinal(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(this.values){

                @Override
                public float score(int node) throws IOException {
                    // Qualified outer calls: the helpers are declared on the supplier.
                    DotProductSupplier.this.checkOrdinal(node);
                    float raw = PanamaVectorUtilSupport.dotProduct(DotProductSupplier.this.getFirstSegment(ord), DotProductSupplier.this.getSecondSegment(node));
                    // 32768 == 2 * 128 * 128; presumably twice the largest magnitude of a product
                    // of two signed bytes, scaled by the dimension - confirm against Lucene's
                    // byte dot-product similarity definition.
                    return 0.5f + raw / (float)(DotProductSupplier.this.values.dimension() * 32768);
                }
            };
        }

        /** Returns a new supplier over a clone of the input, sharing the same values. */
        @Override
        public DotProductSupplier copy() throws IOException {
            return new DotProductSupplier(this.input.clone(), this.values);
        }
    }

    /** Supplier whose scorers compute {@code (1 + cosine) / 2}. */
    static final class CosineSupplier
    extends Lucene99MemorySegmentByteVectorScorerSupplier {
        CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
            super(input, values);
        }

        /**
         * Returns a scorer comparing vector {@code ord} against other ordinals.
         *
         * @throws IllegalArgumentException if {@code ord} is out of range
         */
        @Override
        public RandomVectorScorer scorer(final int ord) {
            this.checkOrdinal(ord);
            return new RandomVectorScorer.AbstractRandomVectorScorer(this.values){

                @Override
                public float score(int node) throws IOException {
                    // Qualified outer calls: the helpers are declared on the supplier.
                    CosineSupplier.this.checkOrdinal(node);
                    float raw = PanamaVectorUtilSupport.cosine(CosineSupplier.this.getFirstSegment(ord), CosineSupplier.this.getSecondSegment(node));
                    return (1.0f + raw) / 2.0f;
                }
            };
        }

        /** Returns a new supplier over a clone of the input, sharing the same values. */
        @Override
        public CosineSupplier copy() throws IOException {
            return new CosineSupplier(this.input.clone(), this.values);
        }
    }
}

