/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.vector.impl.storage;

import com.hazelcast.config.vector.Metric;
import com.hazelcast.core.HazelcastException;
import com.hazelcast.internal.memory.Measurable;
import com.hazelcast.internal.serialization.Data;
import com.hazelcast.internal.util.JVMUtil;
import com.hazelcast.internal.util.ThreadUtil;
import com.hazelcast.logging.ILogger;
import com.hazelcast.nio.ObjectDataOutput;
import com.hazelcast.shaded.io.github.jbellis.jvector.annotations.VisibleForTesting;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.GraphIndexBuilder;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.GraphSearcher;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.HazelcastGraphIndexViewProvider;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.SearchResult;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.BitSet;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.Bits;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.VectorFloat;
import com.hazelcast.spi.impl.InternalCompletableFuture;
import com.hazelcast.spi.impl.NodeEngine;
import com.hazelcast.vector.IndexMutationDisallowedException;
import com.hazelcast.vector.SearchResults;
import com.hazelcast.vector.impl.stats.VectorIndexStats;
import com.hazelcast.vector.impl.stats.VectorIndexStatsImpl;
import com.hazelcast.vector.impl.storage.ConcurrentModificationPolicy;
import com.hazelcast.vector.impl.storage.ReplicationStateHolder;
import com.hazelcast.vector.impl.storage.UpdatableVectorsSource;
import com.hazelcast.vector.impl.storage.VectorCollectionStorage;
import com.hazelcast.vector.impl.storage.graph.HazelcastBuildScoreProvider;
import com.hazelcast.vector.impl.storage.graph.HazelcastBuiltinVectorSimilarityFunction;
import com.hazelcast.vector.impl.storage.graph.HazelcastVectorSimilarityFunction;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import javax.annotation.Nonnull;

/**
 * Base class for a single vector index backed by a jvector {@link GraphIndexBuilder}.
 * <p>
 * Vectors are stored in an {@link UpdatableVectorsSource} and addressed by integer node ids;
 * subclasses own the mapping between entry keys and node ids. While an optimization holds the
 * mutation lock ({@link #lockIndexMutation(UUID)}), mutating operations are rejected with
 * {@link IndexMutationDisallowedException}.
 */
public abstract class AbstractVectorIndex
implements Measurable {
    /**
     * Size-independent part of {@link #heapBytesUsed()}: this object's header, its 7 reference
     * fields and 5 {@code int} fields (20 bytes), the {@link AtomicReference} holding the
     * optimization state (header + one reference), and the fixed size of the stats object.
     */
    static final long FIXED_HEAP_BYTES_USED = (long)JVMUtil.OBJECT_HEADER_SIZE
            + 7L * (long)JVMUtil.REFERENCE_COST_IN_BYTES
            + 20L
            + (long)JVMUtil.OBJECT_HEADER_SIZE
            + (long)JVMUtil.REFERENCE_COST_IN_BYTES
            + VectorIndexStatsImpl.FIXED_HEAP_BYTES_USED;

    // Node id counter used by subclasses (presumably the next id to assign — confirm in subclasses);
    // reset to 0 by maybeRecreateIndex() and restored by resetState().
    protected int idGenerator;
    // Vector storage read by both the graph builder and the search score function.
    volatile UpdatableVectorsSource vectorsSupplier;
    @VisibleForTesting
    GraphIndexBuilder indexBuilder;
    final String indexName;
    private final int maxDegree;
    private final int efConstruction;
    private final int dimensions;
    private final HazelcastVectorSimilarityFunction similarityFunction;
    // Pool of searchers over the current graph; replaced whenever the index builder is replaced.
    private volatile SearcherPool searcherPool;
    // Incremented every time the index is recreated; exposed as the search logical time.
    private volatile int indexVersion;
    // Non-null while an optimization holds the mutation lock.
    private final AtomicReference<OptimizationState> indexOptimizationState = new AtomicReference<>();
    private final VectorIndexStatsImpl stats = new VectorIndexStatsImpl();

    /**
     * Creates an empty index.
     * <p>
     * Note: this constructor calls the overridable {@link #createIndexBuilder}, so an override
     * runs before the subclass's own fields have been initialized.
     *
     * @param indexName      name of the index, used in error and log messages
     * @param metric         similarity metric used for both building and searching
     * @param maxDegree      maximum degree passed to the {@link GraphIndexBuilder}
     * @param efConstruction efConstruction value passed to the {@link GraphIndexBuilder}
     * @param dimensions     required length of every vector stored in or queried against the index
     */
    AbstractVectorIndex(String indexName, Metric metric, int maxDegree, int efConstruction, int dimensions) {
        this.indexName = indexName;
        this.maxDegree = maxDegree;
        this.efConstruction = efConstruction;
        this.dimensions = dimensions;
        this.similarityFunction = this.asVectorSimilarityFunction(metric);
        this.vectorsSupplier = new UpdatableVectorsSource(dimensions);
        this.indexBuilder = this.createIndexBuilder(this.vectorsSupplier);
        this.searcherPool = this.createSearcherPool(this.indexBuilder);
    }

    /**
     * Creates the graph builder over the given vectors, using this index's similarity function,
     * max degree and efConstruction.
     * <p>
     * The trailing {@code 1.2f} / {@code 1.4f} arguments are presumably jvector's neighbor-overflow
     * and alpha parameters — confirm against the shaded jvector version.
     *
     * @param vectorsSource vectors the graph is built over
     * @return a new, empty graph index builder
     */
    protected GraphIndexBuilder createIndexBuilder(RandomAccessVectorValues vectorsSource) {
        return new GraphIndexBuilder(new HazelcastBuildScoreProvider(vectorsSource, this.similarityFunction),
                vectorsSource.dimension(), this.maxDegree, this.efConstruction, 1.2f, 1.4f);
    }

    /** Creates a searcher pool bound to the builder's current graph. */
    private SearcherPool createSearcherPool(GraphIndexBuilder indexBuilder) {
        return new SearcherPool(indexBuilder.getGraph());
    }

    /** Maps the public {@link Metric} to the corresponding similarity function. */
    private HazelcastVectorSimilarityFunction asVectorSimilarityFunction(Metric m) {
        // Exhaustive over the enum: adding a Metric constant makes this a compile error.
        return switch (m) {
            case DOT -> HazelcastBuiltinVectorSimilarityFunction.DOT_PRODUCT;
            case COSINE -> HazelcastBuiltinVectorSimilarityFunction.cosine();
            case EUCLIDEAN -> HazelcastBuiltinVectorSimilarityFunction.EUCLIDEAN;
        };
    }

    /** @return the efConstruction value this index was created with */
    public int getEfConstruction() {
        return this.efConstruction;
    }

    /**
     * Stores a vector for the given key.
     *
     * @param key            entry key
     * @param vector         vector to store; its length must equal the index dimension
     * @param returnPrevious whether to return the vector previously stored for the key
     * @return the previous vector if {@code returnPrevious} is set and one existed, otherwise {@code null}
     * @throws IndexMutationDisallowedException if the index is locked for optimization
     * @throws IllegalArgumentException         if the vector length does not match the index dimension
     */
    public VectorFloat<?> put(Data key, VectorFloat<?> vector, boolean returnPrevious) {
        this.checkMutatingOperationAllowed();
        this.validateVector(vector);
        VectorFloat<?> previousVector = this.putInternal(key, vector);
        return returnPrevious ? previousVector : null;
    }

    /**
     * Stores a vector for the given key, discarding any previous vector.
     *
     * @throws IndexMutationDisallowedException if the index is locked for optimization
     * @throws IllegalArgumentException         if the vector length does not match the index dimension
     */
    public void put(Data key, VectorFloat<?> vector) {
        this.checkMutatingOperationAllowed();
        this.validateVector(vector);
        this.putInternal(key, vector);
    }

    /**
     * Returns the vector stored for the given key.
     *
     * @throws IllegalStateException if no entry exists for the key
     */
    public VectorFloat<?> get(Data key) {
        Integer nodeId = this.getNodeIdByKey(key);
        if (nodeId == null) {
            throw new IllegalStateException("The entry with the requested key does not exist.");
        }
        return this.vectorsSupplier.getVector(nodeId);
    }

    /**
     * Deletes the entry for the given key.
     *
     * @return the result of {@link #deleteInternal(Data)}
     * @throws IndexMutationDisallowedException if the index is locked for optimization
     */
    public boolean delete(Data key) {
        this.checkMutatingOperationAllowed();
        return this.deleteInternal(key);
    }

    /**
     * Runs a full cleanup of the graph as the optimization started by {@link #lockIndexMutation(UUID)}.
     * <p>
     * The outcome is reported through the optimization's future: an {@link Exception} from the
     * cleanup completes it exceptionally and is not rethrown. An {@link Error} also completes the
     * future exceptionally (so threads waiting in {@code prepareForMigration} are released) and is
     * then rethrown.
     *
     * @throws IllegalStateException if the index is not locked or this optimization already finished
     */
    public void cleanup() {
        OptimizationState optimizationState = this.indexOptimizationState.get();
        if (optimizationState == null) {
            throw new IllegalStateException("The index must be locked for mutation before the optimization process begins.");
        }
        if (optimizationState.isDone()) {
            throw new IllegalStateException("Optimization already done, lock is not reusable");
        }
        try {
            this.cleanupInternal(CleanupMode.FULL);
            optimizationState.complete();
        }
        catch (Exception e) {
            optimizationState.completeExceptionally(e);
        }
        catch (Error e) {
            // Never leave the future pending: join() in prepareForMigration would block forever.
            optimizationState.completeExceptionally(e);
            throw e;
        }
    }

    /**
     * Removes deleted nodes from the graph and then drops their vectors from the vector source.
     *
     * @param mode {@link CleanupMode#DELETED_NODES} only removes deleted nodes;
     *             {@link CleanupMode#FULL} runs the builder's full cleanup
     */
    private void cleanupInternal(CleanupMode mode) {
        // Snapshot deleted ids before the graph cleanup runs; their vectors are removed afterwards.
        var removedNodeIds = new ArrayList<Integer>();
        BitSet deletedNodes = this.getRemovedNodes();
        // nextSetBit signals exhaustion with Integer.MAX_VALUE (presumably jvector's NO_MORE_DOCS).
        for (int node = deletedNodes.nextSetBit(0); node != Integer.MAX_VALUE; node = deletedNodes.nextSetBit(node + 1)) {
            removedNodeIds.add(node);
        }
        if (mode == CleanupMode.DELETED_NODES) {
            this.indexBuilder.removeDeletedNodes();
        } else {
            this.indexBuilder.cleanup();
        }
        removedNodeIds.forEach(this.vectorsSupplier::remove);
    }

    /** @return the index version, incremented each time {@link #maybeRecreateIndex()} recreates the index */
    int getSearchLogicalTime() {
        return this.indexVersion;
    }

    /** Searches with {@code rerankK == topK} and {@link ConcurrentModificationPolicy#THROW}. */
    SearchResults<Data, Data> search(VectorFloat<?> vector, int topK) {
        return this.search(vector, topK, topK, ConcurrentModificationPolicy.THROW);
    }

    /**
     * Searches the index for the vectors most similar to {@code vector} and records query stats.
     *
     * @param vector   query vector
     * @param topK     number of results requested; must be positive
     * @param rerankK  rerank count passed to the graph searcher
     * @param cmPolicy action taken when a visited node has no vector (e.g. it was removed meanwhile)
     * @return results converted by {@link #toDataSearchResults}
     * @throws IllegalArgumentException if {@code topK <= 0}
     */
    public SearchResults<Data, Data> search(VectorFloat<?> vector, int topK, int rerankK, ConcurrentModificationPolicy cmPolicy) {
        if (topK <= 0) {
            throw new IllegalArgumentException("topK must be > 0");
        }
        SearchScoreProvider scoreProvider = this.createSearchScoreProvider(vector, cmPolicy);
        // Read the volatile field once so the searcher goes back to the pool it came from,
        // even if resetState() swaps the pool concurrently.
        SearcherPool pool = this.searcherPool;
        GraphSearcher searcher = pool.acquire();
        SearchResult searchResult;
        try {
            searchResult = searcher.search(scoreProvider, topK, rerankK, 0.0f, 0.0f, Bits.ALL);
        }
        finally {
            pool.release(searcher);
        }
        this.stats.incrementQueryCount();
        this.stats.addVisitedNodes(searchResult.getVisitedCount());
        return this.toDataSearchResults(searchResult, topK, cmPolicy);
    }

    /**
     * Builds an exact score provider comparing {@code vector} with stored vectors. A node whose
     * vector is missing triggers {@code cmPolicy.action()} and, if that returns, scores
     * {@link Float#NEGATIVE_INFINITY}.
     */
    private SearchScoreProvider createSearchScoreProvider(final VectorFloat<?> vector, final ConcurrentModificationPolicy cmPolicy) {
        ScoreFunction.ExactScoreFunction sf = new ScoreFunction.ExactScoreFunction(){

            @Override
            public float similarityTo(int node2) {
                VectorFloat<?> vector2 = AbstractVectorIndex.this.vectorsSupplier.getVector(node2);
                if (vector2 == null) {
                    cmPolicy.action();
                    return Float.NEGATIVE_INFINITY;
                }
                return AbstractVectorIndex.this.similarityFunction.compare(vector, vector2);
            }
        };
        return new SearchScoreProvider(sf);
    }

    /**
     * Acquires the mutation lock for an optimization identified by {@code uuid}.
     *
     * @throws IndexMutationDisallowedException if another optimization already holds the lock
     */
    public void lockIndexMutation(UUID uuid) {
        OptimizationState currentOptimization = this.indexOptimizationState.compareAndExchange(null, new OptimizationState(uuid));
        if (currentOptimization != null) {
            throw new IndexMutationDisallowedException("Index optimization process with uuid = " + currentOptimization.uuid() + " is in progress.");
        }
    }

    /**
     * Releases the mutation lock. An unfinished optimization is first completed exceptionally
     * with a {@link HazelcastException}.
     *
     * @throws IllegalStateException if the index was not locked (or the lock changed concurrently)
     */
    public void unlockIndexMutation() {
        OptimizationState currentState = this.indexOptimizationState.get();
        if (currentState != null && !currentState.isDone()) {
            currentState.completeExceptionally(new HazelcastException("Index optimization terminated by unlocking"));
        }
        if (currentState == null || !this.indexOptimizationState.compareAndSet(currentState, null)) {
            throw new IllegalStateException("Failed to unlock a collection that was not locked.");
        }
    }

    /** @return {@code true} if no optimization currently holds the mutation lock */
    public boolean isMutatingOperationAllowed() {
        return this.indexOptimizationState.get() == null;
    }

    /** @throws IndexMutationDisallowedException if an optimization holds the mutation lock */
    public void checkMutatingOperationAllowed() {
        if (!this.isMutatingOperationAllowed()) {
            throw new IndexMutationDisallowedException("Index optimization process is in progress.");
        }
    }

    /** @return the live statistics object of this index */
    public VectorIndexStats getStats() {
        return this.stats;
    }

    /** @return the graph's set of nodes marked as deleted */
    protected BitSet getRemovedNodes() {
        return this.indexBuilder.getGraph().getDeletedNodes();
    }

    /** Removes the entry for {@code key}; called after the mutation-lock check. */
    protected abstract boolean deleteInternal(Data key);

    /** Stores {@code vector} for {@code key} and returns the previous vector, if any. */
    protected abstract VectorFloat<?> putInternal(Data key, VectorFloat<?> vector);

    /** Converts a raw graph search result into key/value search results. */
    protected abstract SearchResults<Data, Data> toDataSearchResults(SearchResult searchResult, int topK, ConcurrentModificationPolicy cmPolicy);

    /** @return the node id mapped to {@code key}, or {@code null} if there is none */
    protected abstract Integer getNodeIdByKey(Data key);

    /** @return whether the index still contains any live (non-deleted) nodes */
    protected abstract boolean hasLiveNodes();

    /**
     * If the index has no live nodes, fully cleans it up, resets {@link #idGenerator} to 0 and
     * increments the index version.
     *
     * @return {@code true} if the index was recreated
     */
    protected boolean maybeRecreateIndex() {
        if (this.hasLiveNodes()) {
            return false;
        }
        this.cleanupInternal(CleanupMode.FULL);
        this.idGenerator = 0;
        ++this.indexVersion;
        return true;
    }

    private void validateVector(VectorFloat<?> vector) {
        this.validateDimension(vector.length());
    }

    /** @throws IllegalArgumentException if the vector length does not match the index dimension */
    void validateVector(float[] vector) {
        this.validateDimension(vector.length);
    }

    private void validateDimension(int length) {
        if (length != this.dimensions) {
            throw new IllegalArgumentException("Vector length " + length + " different than expected for index " + this.indexName);
        }
    }

    /** @return fixed overhead plus the heap used by the stored vectors and the graph */
    @Override
    public long heapBytesUsed() {
        return FIXED_HEAP_BYTES_USED + this.vectorsSupplier.heapBytesUsed() + this.indexBuilder.getGraph().ramBytesUsed();
    }

    /**
     * Prepares the index for partition migration. Must not run on a partition thread.
     * <p>
     * If an optimization is in progress, blocks until it completes (its full cleanup already
     * removes deleted nodes). Otherwise removes deleted nodes from the graph and their vectors.
     */
    void prepareForMigration(VectorCollectionStorage vectorCollectionStorage, NodeEngine nodeEngine) {
        assert !ThreadUtil.isRunningOnPartitionThread() : "Preparation should be offloaded";
        assert nodeEngine.getPartitionService().getPartition(vectorCollectionStorage.getPartitionId()).isMigrating()
                : "Partition should be marked as migrating";
        OptimizationState optimizationState = this.indexOptimizationState.get();
        if (optimizationState != null) {
            ILogger logger = nodeEngine.getLogger(AbstractVectorIndex.class);
            logger.info("Index optimization in progress for partitionId=" + vectorCollectionStorage.getPartitionId() + " index name=" + this.indexName + ", waiting for completion before continuing migration");
            optimizationState.join();
            return;
        }
        this.cleanupInternal(CleanupMode.DELETED_NODES);
    }

    /** Writes the node id mapped to {@code key} to {@code out}. */
    abstract void writeKeyToNodeIdMapping(ObjectDataOutput out, Data key) throws IOException;

    /**
     * Replaces this index's state with the replicated one.
     *
     * @throws IndexMutationDisallowedException if the index is locked for optimization
     */
    void resetState(ReplicationStateHolder.CollectionReplicationStateHolder.IndexReplicationStateHolder indexState) {
        this.checkMutatingOperationAllowed();
        this.idGenerator = indexState.idGeneratorState;
        // Publish the vectors before the searcher pool: search() reads searcherPool first and
        // vectorsSupplier later, so a search that sees the new graph must also see the new vectors.
        this.vectorsSupplier = indexState.vectorsSupplier;
        this.indexBuilder = indexState.index.indexBuilder;
        this.searcherPool = this.createSearcherPool(this.indexBuilder);
    }

    /**
     * Keeps at most one idle {@link GraphSearcher} for reuse. When the slot is empty, a new
     * searcher is created; a released searcher is dropped if the slot is already occupied.
     */
    private static class SearcherPool {
        private final AtomicReference<GraphSearcher> pool;
        private final Supplier<GraphSearcher> graphSearcherSupplier;

        SearcherPool(OnHeapGraphIndex graph) {
            // Built here (not in a field initializer) so the lambda can capture the constructor argument.
            this.graphSearcherSupplier = () -> new GraphSearcher(new HazelcastGraphIndexViewProvider(graph));
            this.pool = new AtomicReference<>(this.graphSearcherSupplier.get());
        }

        public GraphSearcher acquire() {
            GraphSearcher pooledSearcher = this.pool.getAndSet(null);
            return pooledSearcher != null ? pooledSearcher : this.graphSearcherSupplier.get();
        }

        public void release(GraphSearcher searcher) {
            this.pool.compareAndSet(null, searcher);
        }
    }

    /** An optimization holding the mutation lock, and the future tracking its completion. */
    private record OptimizationState(@Nonnull UUID uuid, @Nonnull InternalCompletableFuture<Void> indexOptimizationFuture) {
        OptimizationState(@Nonnull UUID uuid) {
            this(Objects.requireNonNull(uuid), new InternalCompletableFuture<>());
        }

        boolean isDone() {
            return this.indexOptimizationFuture.isDone();
        }

        void complete() {
            this.indexOptimizationFuture.complete(null);
        }

        void completeExceptionally(Throwable t) {
            this.indexOptimizationFuture.completeExceptionally(t);
        }

        /** Blocks until the optimization finishes, rethrowing its failure if any. */
        void join() {
            this.indexOptimizationFuture.joinInternal();
        }
    }

    private enum CleanupMode {
        // Builder's full cleanup.
        FULL,
        // Only remove nodes marked as deleted.
        DELETED_NODES
    }
}

