/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.shaded.io.github.jbellis.jvector.pq;

import com.hazelcast.shaded.io.github.jbellis.jvector.annotations.VisibleForTesting;
import com.hazelcast.shaded.io.github.jbellis.jvector.disk.RandomAccessReader;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import com.hazelcast.shaded.io.github.jbellis.jvector.pq.CompressedVectors;
import com.hazelcast.shaded.io.github.jbellis.jvector.pq.ImmutablePQVectors;
import com.hazelcast.shaded.io.github.jbellis.jvector.pq.PQDecoder;
import com.hazelcast.shaded.io.github.jbellis.jvector.pq.ProductQuantization;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.RamUsageEstimator;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.VectorUtil;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.VectorizationProvider;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.ByteSequence;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.VectorFloat;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.VectorTypeSupport;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.IntStream;

/**
 * Base class for a collection of vectors compressed with a {@link ProductQuantization}.
 *
 * <p>The encoded bytes of all vectors are kept in one or more {@link ByteSequence} chunks.
 * Every chunk except possibly the last holds {@link #vectorsPerChunk} vectors, and the
 * layout is chosen by {@link #calculateChunkParameters(int, int)} so that no chunk exceeds
 * {@link #MAX_CHUNK_SIZE} bytes.
 *
 * <p>This constructor only stores the quantizer; {@link #compressedDataChunks},
 * {@link #vectorCount} and {@link #vectorsPerChunk} are left for subclasses to populate.
 */
public abstract class PQVectors implements CompressedVectors {
    private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();

    /** Largest number of bytes stored in a single chunk (equal to {@code 0x7FFFFFEF}). */
    static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16;

    /** Quantizer used to encode the vectors and to score queries against them. */
    final ProductQuantization pq;
    /** Encoded vectors, laid out back to back inside each chunk. */
    protected ByteSequence<?>[] compressedDataChunks;
    /** Number of vectors addressable through {@link #get(int)}. */
    protected int vectorCount;
    /** Number of vectors stored in every full-size chunk. */
    protected int vectorsPerChunk;

    /**
     * Stores the quantizer; the remaining fields are not initialised here.
     *
     * @param pq the product quantization backing these vectors
     */
    protected PQVectors(ProductQuantization pq) {
        this.pq = pq;
    }

    /**
     * Reads a {@link ProductQuantization}, then the vector count and the compressed dimension
     * (both as {@code int}), then the chunk data, starting at the reader's current position.
     *
     * @param in the reader to consume
     * @return the loaded vectors
     * @throws IOException if reading fails
     * @throws IllegalArgumentException if the stored vector count or compressed dimension is
     *         rejected by {@link #calculateChunkParameters(int, int)}
     */
    public static ImmutablePQVectors load(RandomAccessReader in) throws IOException {
        ProductQuantization pq = ProductQuantization.load(in);
        int vectorCount = in.readInt();
        int compressedDimension = in.readInt();

        int[] layout = calculateChunkParameters(vectorCount, compressedDimension);
        int vectorsPerChunk = layout[0];
        int totalChunks = layout[1];
        int fullSizeChunks = layout[2];
        int remainingVectors = layout[3];

        ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
        int chunkBytes = vectorsPerChunk * compressedDimension;
        for (int i = 0; i < fullSizeChunks; i++) {
            chunks[i] = vectorTypeSupport.readByteSequence(in, chunkBytes);
        }
        // A trailing partial chunk exists only when the count is not a multiple of vectorsPerChunk.
        if (totalChunks > fullSizeChunks) {
            chunks[fullSizeChunks] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
        }
        return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
    }

    /**
     * Computes how {@code vectorCount} encoded vectors of {@code compressedDimension} bytes each
     * are split into chunks of at most {@link #MAX_CHUNK_SIZE} bytes.
     *
     * @param vectorCount number of vectors; must be positive
     * @param compressedDimension bytes per encoded vector; must not be negative
     * @return {@code {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors}}, where
     *         {@code remainingVectors} is the number of vectors in the trailing partial chunk
     *         (zero when every chunk is full)
     * @throws IllegalArgumentException if an argument is out of range, or if a single encoded
     *         vector does not fit in one chunk
     */
    @VisibleForTesting
    static int[] calculateChunkParameters(int vectorCount, int compressedDimension) {
        if (vectorCount < 0) {
            throw new IllegalArgumentException("Invalid vector count " + vectorCount);
        }
        if (compressedDimension < 0) {
            throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
        }
        // Zero vectors would yield vectorsPerChunk == 0 below; report the argument actually at fault.
        if (vectorCount == 0) {
            throw new IllegalArgumentException("Invalid vector count " + vectorCount);
        }

        // Widen before multiplying so the product cannot overflow int.
        long totalSize = (long) vectorCount * compressedDimension;
        int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;
        if (vectorsPerChunk == 0) {
            throw new IllegalArgumentException("Compressed dimension " + compressedDimension + " too large for chunking");
        }

        int fullSizeChunks = vectorCount / vectorsPerChunk;
        int remainingVectors = vectorCount % vectorsPerChunk;
        int totalChunks = remainingVectors == 0 ? fullSizeChunks : fullSizeChunks + 1;
        return new int[]{vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors};
    }

    /**
     * Seeks {@code in} to {@code offset} and then delegates to {@link #load(RandomAccessReader)}.
     *
     * @param in the reader to consume
     * @param offset absolute position to seek to before reading
     * @return the loaded vectors
     * @throws IOException if seeking or reading fails
     */
    public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
        in.seek(offset);
        return load(in);
    }

    /**
     * Encodes the vectors of {@code ravv} with {@code pq} and returns them as immutable PQ vectors.
     *
     * <p>The chunk layout comes from {@link #calculateChunkParameters(int, int)}, the same
     * routine {@link #load(RandomAccessReader)} uses, so a trailing partial chunk gets its own
     * array slot instead of being folded into the last full chunk.
     *
     * <p>Ordinals {@code 0 .. ravv.size() - 1} are encoded in parallel on {@code simdExecutor};
     * an ordinal whose vector is {@code null} has its slot zeroed.
     * NOTE(review): the loop bound is {@code ravv.size()}, not {@code vectorCount}; callers are
     * presumably expected to keep the two consistent.
     *
     * @param pq the quantizer used for encoding
     * @param vectorCount number of vectors to allocate room for; must be positive
     * @param ravv source of the uncompressed vectors
     * @param simdExecutor pool on which the parallel encoding runs
     * @return the encoded vectors
     * @throws IllegalArgumentException if {@link #calculateChunkParameters(int, int)} rejects the
     *         vector count or the compressed vector size
     */
    public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
        int compressedDimension = pq.compressedVectorSize();

        int[] layout = calculateChunkParameters(vectorCount, compressedDimension);
        int vectorsPerChunk = layout[0];
        int totalChunks = layout[1];
        int fullSizeChunks = layout[2];
        int remainingVectors = layout[3];

        ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
        int chunkBytes = vectorsPerChunk * compressedDimension;
        for (int i = 0; i < fullSizeChunks; i++) {
            chunks[i] = vectorTypeSupport.createByteSequence(chunkBytes);
        }
        // The trailing chunk, if any, is strictly smaller than a full chunk.
        if (totalChunks > fullSizeChunks) {
            chunks[fullSizeChunks] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension);
        }

        // Submitting to the pool makes the parallel stream run on simdExecutor; join() waits for completion.
        simdExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel().forEach(ordinal -> {
            ByteSequence<?> slice = get(chunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
            VectorFloat<?> vector = ravv.getVector(ordinal);
            if (vector != null) {
                pq.encodeTo(vector, slice);
            } else {
                slice.zero();
            }
        })).join();

        return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
    }

    /** @return the number of vectors, i.e. {@link #vectorCount} */
    @Override
    public int count() {
        return vectorCount;
    }

    /**
     * Writes the quantizer, the vector count, the subspace count and then the first
     * {@link #validChunkCount()} chunks, in that order.
     *
     * @param out destination
     * @param version format version, passed through to the quantizer's {@code write}
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out, int version) throws IOException {
        pq.write(out, version);
        out.writeInt(vectorCount);
        out.writeInt(pq.getSubspaceCount());
        for (int i = 0; i < validChunkCount(); i++) {
            vectorTypeSupport.writeByteSequence(out, compressedDataChunks[i]);
        }
    }

    /**
     * @return how many leading entries of {@link #compressedDataChunks} are written by
     *         {@link #write(DataOutput, int)} and counted by {@link #ramBytesUsed()}
     */
    protected abstract int validChunkCount();

    /**
     * Two instances are equal when they have the same runtime class, equal quantizers, the same
     * count, and equal encoded bytes for every ordinal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        PQVectors that = (PQVectors) o;
        if (!Objects.equals(pq, that.pq)) {
            return false;
        }
        if (count() != that.count()) {
            return false;
        }
        for (int i = 0; i < count(); i++) {
            if (!get(i).equals(that.get(i))) {
                return false;
            }
        }
        return true;
    }

    /** Combines the quantizer, the count and every encoded vector, mirroring {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = 1;
        result = 31 * result + pq.hashCode();
        result = 31 * result + count();
        for (int i = 0; i < count(); i++) {
            result = 31 * result + get(i).hashCode();
        }
        return result;
    }

    /**
     * Returns the {@link PQDecoder} variant that matches {@code similarityFunction}.
     *
     * @param q the query vector
     * @param similarityFunction the similarity to approximate
     * @return a decoder-backed score function
     * @throws IllegalArgumentException if the similarity function is not one of
     *         {@code DOT_PRODUCT}, {@code EUCLIDEAN} or {@code COSINE}
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return new PQDecoder.DotProductDecoder(this, q);
            case EUCLIDEAN:
                return new PQDecoder.EuclideanDecoder(this, q);
            case COSINE:
                return new PQDecoder.CosineDecoder(this, q);
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Returns a score function that, for each node, walks the node's encoded bytes and compares
     * the selected codebook centroid of every subspace with the matching slice of the query.
     *
     * <p>If the quantizer has a global centroid, it is subtracted from {@code q} first.
     * The returned scores are:
     * <ul>
     *   <li>{@code DOT_PRODUCT}: {@code (1 + dot) / 2}</li>
     *   <li>{@code COSINE}: {@code (1 + cosine) / 2}</li>
     *   <li>{@code EUCLIDEAN}: {@code 1 / (1 + squaredDistance)}</li>
     * </ul>
     *
     * @param q the query vector
     * @param similarityFunction the similarity to approximate
     * @return a score function evaluated directly against the codebooks
     * @throws IllegalArgumentException if the similarity function is not supported
     */
    @Override
    public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q, VectorSimilarityFunction similarityFunction) {
        VectorFloat<?> centeredQuery = pq.globalCentroid == null ? q : VectorUtil.sub(q, pq.globalCentroid);
        switch (similarityFunction) {
            case DOT_PRODUCT:
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    float dp = 0.0f;
                    for (int m = 0; m < pq.getSubspaceCount(); m++) {
                        // Each encoded byte is an unsigned index into subspace m's codebook.
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
                    }
                    return (1.0f + dp) / 2.0f;
                };
            case COSINE: {
                // Squared magnitude of the query; it is the same for every node, so compute it once.
                float queryNormSquared = VectorUtil.dotProduct(centeredQuery, centeredQuery);
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    float sum = 0.0f;
                    float nodeNormSquared = 0.0f;
                    for (int m = 0; m < pq.getSubspaceCount(); m++) {
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        int codebookOffset = centroidIndex * centroidLength;
                        sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength);
                        nodeNormSquared += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength);
                    }
                    float cosine = sum / (float) Math.sqrt(queryNormSquared * nodeNormSquared);
                    return (1.0f + cosine) / 2.0f;
                };
            }
            case EUCLIDEAN:
                return node -> {
                    ByteSequence<?> encoded = get(node);
                    float sum = 0.0f;
                    for (int m = 0; m < pq.getSubspaceCount(); m++) {
                        int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
                        int centroidLength = pq.subvectorSizesAndOffsets[m][0];
                        int centroidOffset = pq.subvectorSizesAndOffsets[m][1];
                        sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength);
                    }
                    return 1.0f / (1.0f + sum);
                };
            default:
                throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction);
        }
    }

    /**
     * Returns the encoded bytes of one vector.
     *
     * @param ordinal index of the vector, in {@code [0, vectorCount)}
     * @return a slice of {@code pq.getSubspaceCount()} bytes taken from the owning chunk
     * @throws IndexOutOfBoundsException if {@code ordinal} is outside {@code [0, vectorCount)}
     */
    public ByteSequence<?> get(int ordinal) {
        if (ordinal < 0 || ordinal >= vectorCount) {
            throw new IndexOutOfBoundsException("Ordinal " + ordinal + " out of bounds for vector count " + vectorCount);
        }
        return get(compressedDataChunks, ordinal, vectorsPerChunk, pq.getSubspaceCount());
    }

    /**
     * Locates a vector inside a chunk array: chunk {@code ordinal / vectorsPerChunk}, byte offset
     * {@code (ordinal % vectorsPerChunk) * subspaceCount}. Performs no bounds checking of its own.
     *
     * @param chunks chunk array to index into
     * @param ordinal index of the vector
     * @param vectorsPerChunk number of vectors in a full chunk
     * @param subspaceCount number of bytes per encoded vector
     * @return the slice of {@code subspaceCount} bytes for that vector
     */
    static ByteSequence<?> get(ByteSequence<?>[] chunks, int ordinal, int vectorsPerChunk, int subspaceCount) {
        int chunkIndex = ordinal / vectorsPerChunk;
        int vectorIndexInChunk = ordinal % vectorsPerChunk;
        int start = vectorIndexInChunk * subspaceCount;
        return chunks[chunkIndex].slice(start, subspaceCount);
    }

    /** @return the quantizer's {@code reusablePartialSums()} */
    VectorFloat<?> reusablePartialSums() {
        return pq.reusablePartialSums();
    }

    /** @return the quantizer's {@code partialSquaredMagnitudes()} */
    AtomicReference<VectorFloat<?>> partialSquaredMagnitudes() {
        return pq.partialSquaredMagnitudes();
    }

    /** @return the quantizer's original dimension multiplied by {@code Float.BYTES} (4) */
    @Override
    public int getOriginalSize() {
        return pq.originalDimension * Float.BYTES;
    }

    /** @return the quantizer's {@code compressedVectorSize()} */
    @Override
    public int getCompressedSize() {
        return pq.compressedVectorSize();
    }

    /** @return the quantizer backing these vectors */
    public ProductQuantization getCompressor() {
        return pq;
    }

    /**
     * Estimates memory use as the quantizer's own estimate, plus the chunk array (object and
     * array headers and one reference per valid chunk), plus each valid chunk's own estimate.
     */
    @Override
    public long ramBytesUsed() {
        int refBytes = RamUsageEstimator.NUM_BYTES_OBJECT_REF;
        int objectHeaderBytes = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
        int arrayHeaderBytes = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;

        long codebooksSize = pq.ramBytesUsed();
        long chunksArraySize = (long) (objectHeaderBytes + arrayHeaderBytes) + (long) validChunkCount() * (long) refBytes;
        long dataSize = 0L;
        for (int i = 0; i < validChunkCount(); i++) {
            dataSize += compressedDataChunks[i].ramBytesUsed();
        }
        return codebooksSize + chunksArraySize + dataSize;
    }

    @Override
    public String toString() {
        return "PQVectors{pq=" + pq + ", count=" + vectorCount + "}";
    }
}

