/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.shaded.io.github.jbellis.jvector.graph;

import com.hazelcast.shaded.io.github.jbellis.jvector.annotations.VisibleForTesting;
import com.hazelcast.shaded.io.github.jbellis.jvector.disk.RandomAccessReader;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.ConcurrentNeighborMap;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.GraphSearcher;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.NodeArray;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.NodesIterator;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.SearchResult;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.BuildScoreProvider;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.similarity.SearchScoreProvider;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.AtomicFixedBitSet;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.BitSet;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.Bits;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.ExceptionUtils;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.ExplicitThreadLocal;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.ThreadSafeGrowableBitSet;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.VectorUtil;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.VectorFloat;
import com.hazelcast.shaded.org.agrona.collections.IntArrayList;
import com.hazelcast.shaded.org.agrona.collections.IntArrayQueue;
import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds an {@link OnHeapGraphIndex} by inserting nodes one at a time.
 *
 * <p>Insertions may run concurrently: {@link #build(RandomAccessVectorValues)} itself calls
 * {@link #addGraphNode(int, VectorFloat)} from a parallel stream on the SIMD executor. Once insertion
 * is finished, {@link #cleanup()} does several things:
 * <ul>
 *   <li>physically removes nodes that were marked deleted;</li>
 *   <li>enforces the degree limit on every node;</li>
 *   <li>recomputes the entry point;</li>
 *   <li>tries to reconnect nodes that are no longer reachable from the entry point.</li>
 * </ul>
 */
public class GraphIndexBuilder implements Closeable {
    private static final Logger logger = LoggerFactory.getLogger(GraphIndexBuilder.class);

    /**
     * Insertion countdown used before the entry point is first recomputed. It is used again whenever
     * no entry node could be chosen.
     */
    private static final int DEFAULT_ENTRY_UPDATE_INTERVAL = 10_000;

    /** Upper bound on the number of passes {@link #reconnectOrphanedNodes()} makes over the graph. */
    private static final int MAX_RECONNECT_PASSES = 5;

    private final int beamWidth;
    /** Per-thread scratch buffer for candidates found by graph search. */
    private final ExplicitThreadLocal<NodeArray> naturalScratch;
    /** Per-thread scratch buffer for candidates taken from concurrently inserted nodes. */
    private final ExplicitThreadLocal<NodeArray> concurrentScratch;
    private final int dimension;
    private final float neighborOverflow;
    private final float alpha;
    @VisibleForTesting
    final OnHeapGraphIndex graph;
    private double averageShortEdges = Double.NaN;
    /** Ids whose {@link #addGraphNode(int, VectorFloat)} call has started but not yet finished. */
    private final ConcurrentSkipListSet<Integer> insertionsInProgress = new ConcurrentSkipListSet<>();
    private final BuildScoreProvider scoreProvider;
    private final ForkJoinPool simdExecutor;
    private final ForkJoinPool parallelExecutor;
    private final ExplicitThreadLocal<GraphSearcher> searchers;
    /** Insertions remaining until the entry point is recomputed by {@link #updateEntryPoint()}. */
    private final AtomicInteger updateEntryNodeIn = new AtomicInteger(DEFAULT_ENTRY_UPDATE_INTERVAL);

    /**
     * Creates a builder that scores candidates through
     * {@link BuildScoreProvider#randomAccessScoreProvider} over {@code vectorValues}. It uses the
     * default executors.
     */
    public GraphIndexBuilder(RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction, int M, int beamWidth, float neighborOverflow, float alpha) {
        this(BuildScoreProvider.randomAccessScoreProvider(vectorValues, similarityFunction), vectorValues.dimension(), M, beamWidth, neighborOverflow, alpha);
    }

    /**
     * Creates a builder using {@link PhysicalCoreExecutor#pool()} for SIMD-heavy work and the common
     * fork-join pool for the remaining parallel passes.
     */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, float alpha) {
        this(scoreProvider, dimension, M, beamWidth, neighborOverflow, alpha, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
    }

    /**
     * Creates a builder.
     *
     * @param scoreProvider    source of the score functions used during construction
     * @param dimension        vector dimensionality (small values enable {@link #maybeImproveOlderNode()})
     * @param M                maximum node degree; must be positive
     * @param beamWidth        search width used when looking for neighbours; must be positive
     * @param neighborOverflow factor (&ge; 1) by which a node's degree may temporarily exceed {@code M}
     * @param alpha            must be positive; forwarded to {@link OnHeapGraphIndex}
     *                         (presumably its neighbour-diversity parameter)
     * @param simdExecutor     pool for score-heavy passes (insertion, reconnection, edge repair)
     * @param parallelExecutor pool for the other parallel passes
     * @throws IllegalArgumentException if any numeric argument is out of range
     */
    public GraphIndexBuilder(BuildScoreProvider scoreProvider, int dimension, int M, int beamWidth, float neighborOverflow, float alpha, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) {
        if (M <= 0) {
            throw new IllegalArgumentException("maxConn must be positive");
        }
        if (beamWidth <= 0) {
            throw new IllegalArgumentException("beamWidth must be positive");
        }
        if (neighborOverflow < 1.0f) {
            throw new IllegalArgumentException("neighborOverflow must be >= 1.0");
        }
        if (alpha <= 0.0f) {
            throw new IllegalArgumentException("alpha must be positive");
        }
        this.scoreProvider = scoreProvider;
        this.dimension = dimension;
        this.neighborOverflow = neighborOverflow;
        this.alpha = alpha;
        this.beamWidth = beamWidth;
        this.simdExecutor = simdExecutor;
        this.parallelExecutor = parallelExecutor;
        int maxOverflowDegree = (int) ((float) M * neighborOverflow);
        this.graph = new OnHeapGraphIndex(M, maxOverflowDegree, scoreProvider, alpha);
        this.searchers = ExplicitThreadLocal.withInitial(() -> new GraphSearcher(this.graph));
        // scratch buffers must hold either a full search beam or a full neighbour list plus one
        this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
        this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(Math.max(beamWidth, M + 1)));
    }

    /**
     * Creates a new builder with the same configuration and topology as {@code other}, with every edge
     * re-scored by {@code newProvider}.
     *
     * <p>Nodes that are marked deleted in {@code other} but not yet removed are copied and stay marked
     * deleted in the result. The entry node is carried over unchanged.
     */
    public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) {
        GraphIndexBuilder newBuilder = new GraphIndexBuilder(newProvider, other.dimension, other.graph.maxDegree(), other.beamWidth, other.neighborOverflow, other.alpha, other.simdExecutor, other.parallelExecutor);
        for (int i = 0; i < other.graph.getIdUpperBound(); ++i) {
            if (!other.graph.containsNode(i)) {
                continue;
            }
            ConcurrentNeighborMap.Neighbors neighbors = other.graph.getNeighbors(i);
            ScoreFunction sf = newProvider.searchProviderFor(i).scoreFunction();
            NodeArray newNeighbors = new NodeArray(neighbors.size());
            for (NodesIterator it = neighbors.iterator(); it.hasNext(); ) {
                int neighbor = it.nextInt();
                // scores come from a different provider, so the ordering has to be rebuilt
                newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor));
            }
            newBuilder.graph.addNode(i, newNeighbors);
            // keep pending deletions pending; otherwise they would silently come back to life
            if (other.graph.getDeletedNodes().get(i)) {
                newBuilder.graph.markDeleted(i);
            }
        }
        newBuilder.graph.updateEntryNode(other.graph.entry());
        return newBuilder;
    }

    /**
     * Inserts every vector of {@code ravv} (ids {@code 0 .. size-1}) in parallel on the SIMD
     * executor. It then runs {@link #cleanup()}.
     *
     * @return the built graph (the same instance as {@link #getGraph()})
     */
    public OnHeapGraphIndex build(RandomAccessVectorValues ravv) {
        Supplier<RandomAccessVectorValues> vectorsPerThread = ravv.threadLocalSupplier();
        int size = ravv.size();
        this.simdExecutor.submit(() -> IntStream.range(0, size).parallel()
                .forEach(node -> this.addGraphNode(node, vectorsPerThread.get().getVector(node))))
                .join();
        this.cleanup();
        return this.graph;
    }

    /**
     * Brings the graph into its final shape after a round of insertions and deletions. It performs
     * these steps in order:
     * <ol>
     *   <li>validates the entry node;</li>
     *   <li>removes deleted nodes;</li>
     *   <li>enforces the degree limit on every node (the mean of the finite values returned is exposed
     *       via {@link #getAverageShortEdges()});</li>
     *   <li>refreshes the entry point;</li>
     *   <li>reconnects orphaned nodes.</li>
     * </ol>
     * It does nothing for an empty graph.
     */
    public void cleanup() {
        if (this.graph.size() == 0) {
            return;
        }
        this.graph.validateEntryNode();
        this.removeDeletedNodes();
        if (this.graph.size() == 0) {
            return;
        }
        this.averageShortEdges = this.parallelExecutor.submit(() -> IntStream.range(0, this.graph.getIdUpperBound())
                .parallel()
                .mapToDouble(this.graph.nodes::enforceDegree)
                .filter(Double::isFinite)
                .average())
                .join()
                .orElse(Double.NaN);
        this.updateEntryPoint();
        this.reconnectOrphanedNodes();
    }

    /**
     * Makes a best-effort attempt, in up to {@link #MAX_RECONNECT_PASSES} passes, to give every node
     * that is unreachable from the entry point an incoming edge from a reachable node.
     *
     * <p>Each reachable node is used as a reconnection target at most once. The entry node is reserved
     * up front so that it is never used as a target. Passes stop early once a pass finds nothing to
     * reconnect.
     */
    private void reconnectOrphanedNodes() {
        AtomicFixedBitSet globalConnectionTargets = new AtomicFixedBitSet(this.graph.getIdUpperBound());
        globalConnectionTargets.set(this.graph.entry());
        for (int pass = 0; pass < MAX_RECONNECT_PASSES; ++pass) {
            // determine what is reachable from the entry point at the start of this pass
            AtomicFixedBitSet connectedNodes = new AtomicFixedBitSet(this.graph.getIdUpperBound());
            int entry = this.graph.entry();
            // the entry point is trivially reachable even if none of its neighbours link back to it
            connectedNodes.set(entry);
            ConcurrentNeighborMap.Neighbors entryNeighbors = this.graph.getNeighbors(entry);
            this.parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size()).parallel()
                    .forEach(idx -> this.findConnected(connectedNodes, entryNeighbors.getNode(idx))))
                    .join();

            AtomicInteger nReconnectAttempts = new AtomicInteger();
            AtomicInteger nReconnectedViaNeighbors = new AtomicInteger();
            AtomicInteger nResumesRun = new AtomicInteger();
            AtomicInteger nReconnectedViaSearch = new AtomicInteger();
            this.simdExecutor.submit(() -> IntStream.range(0, this.graph.getIdUpperBound()).parallel().forEach(node -> {
                if (connectedNodes.get(node) || !this.graph.containsNode(node)) {
                    return;
                }
                nReconnectAttempts.incrementAndGet();
                // cheapest option first: ask one of the node's own reachable neighbours to link back
                NodeArray ownNeighbors = this.graph.getNeighbors(node);
                if (this.connectToClosestNeighbor(node, ownNeighbors, connectedNodes, globalConnectionTargets) != null) {
                    nReconnectedViaNeighbors.incrementAndGet();
                    return;
                }
                if (this.reconnectViaSearch(node, globalConnectionTargets, nResumesRun)) {
                    nReconnectedViaSearch.incrementAndGet();
                }
            })).join();
            logger.debug("Reconnecting {} nodes out of {} on pass {}. {} neighbor reconnects. {} searches/resumes run. {} nodes reconnected via search", new Object[]{nReconnectAttempts.get(), this.graph.size(), pass, nReconnectedViaNeighbors.get(), nResumesRun.get(), nReconnectedViaSearch.get()});
            if (nReconnectAttempts.get() == 0) {
                break;
            }
        }
    }

    /**
     * Searches the graph for nodes near {@code node}. It links the closest one that is not yet a
     * connection target to {@code node}, resuming the search up to {@code 2 * maxDegree} times if
     * no such candidate is found.
     *
     * @return true if {@code node} was reconnected
     */
    private boolean reconnectViaSearch(int node, BitSet connectionTargets, AtomicInteger resumesRun) {
        try (GraphSearcher gs = this.searchers.get()) {
            SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(node);
            int ep = this.graph.entry();
            SearchResult result = gs.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, ep, new ExcludingBits(node));
            NodeArray candidates = new NodeArray(result.getNodes().length);
            toScratchCandidates(result.getNodes(), candidates);
            SearchResult.NodeScore reconnectedTo = this.connectToClosestNeighbor(node, candidates, Bits.ALL, connectionTargets);
            for (int attempt = 0; reconnectedTo == null && attempt < 2 * this.graph.maxDegree; ++attempt) {
                resumesRun.incrementAndGet();
                result = gs.resume(this.beamWidth, this.beamWidth);
                toScratchCandidates(result.getNodes(), candidates);
                reconnectedTo = this.connectToClosestNeighbor(node, candidates, Bits.ALL, connectionTargets);
            }
            if (reconnectedTo == null) {
                return false;
            }
            // NOTE(review): reinforces the new link through backlink() with no overflow; the exact edge
            // direction this adds is defined by ConcurrentNeighborMap.backlink
            NodeArray link = new NodeArray(1);
            link.addInOrder(reconnectedTo.node, reconnectedTo.score);
            this.graph.nodes.backlink(link, node, 1.0f);
            return true;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Walks {@code neighbors} in order and picks the first one that is set in {@code connectedNodes}
     * and not yet in {@code connectionTargets}. For that neighbour it adds the edge
     * {@code neighbor -> node} (without a diversity check) and claims it as a connection target.
     *
     * @return the chosen neighbour and its score, or {@code null} if none qualified
     */
    private SearchResult.NodeScore connectToClosestNeighbor(int node, NodeArray neighbors, Bits connectedNodes, BitSet connectionTargets) {
        for (int i = 0; i < neighbors.size(); ++i) {
            int neighborNode = neighbors.getNode(i);
            boolean usable = connectedNodes.get(neighborNode) && !connectionTargets.get(neighborNode);
            if (!usable) {
                continue;
            }
            float neighborScore = neighbors.getScore(i);
            this.graph.nodes.insertEdgeNotDiverse(neighborNode, node, neighborScore);
            connectionTargets.set(neighborNode);
            return new SearchResult.NodeScore(neighborNode, neighborScore);
        }
        return null;
    }

    /**
     * Breadth-first traversal from {@code start}. It sets the bit of every node it reaches in
     * {@code connectedNodes}. Nodes whose bit is already set are not expanded again, so several
     * threads can share one bitset.
     */
    private void findConnected(AtomicFixedBitSet connectedNodes, int start) {
        IntArrayQueue queue = new IntArrayQueue();
        queue.add(start);
        try (OnHeapGraphIndex.ConcurrentGraphIndexView view = this.graph.getView()) {
            while (!queue.isEmpty()) {
                int next = queue.pollInt();
                if (connectedNodes.getAndSet(next)) {
                    continue;
                }
                for (NodesIterator it = view.getNeighborsIterator(next); it.hasNext(); ) {
                    queue.addInt(it.nextInt());
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the graph being built (live view, not a copy). */
    public OnHeapGraphIndex getGraph() {
        return this.graph;
    }

    /** Returns the number of {@link #addGraphNode} calls that are currently in flight. */
    public int insertsInProgress() {
        return this.insertionsInProgress.size();
    }

    /**
     * Inserts node {@code node} using {@code ravv.getVector(node)} as its vector.
     *
     * @deprecated use {@link #addGraphNode(int, VectorFloat)}
     */
    @Deprecated
    public long addGraphNode(int node, RandomAccessVectorValues ravv) {
        return this.addGraphNode(node, ravv.getVector(node));
    }

    /**
     * Inserts {@code node} with the given vector. It searches the graph for neighbours and merges
     * them with the nodes whose insertion is concurrently in progress. It then links {@code node} to
     * a diverse subset and adds back-links (allowing overflow). Safe to call from multiple threads.
     *
     * @return {@code graph.ramBytesUsedOneNode()}, presumably the per-node memory estimate
     * @throws RuntimeException wrapping any exception raised during insertion
     */
    public long addGraphNode(int node, VectorFloat<?> vector) {
        this.graph.addNode(node);
        this.insertionsInProgress.add(node);
        // snapshot the nodes being inserted right now; they are offered as extra candidates because
        // they may not yet be linked well enough for the search below to find them
        Set<Integer> inProgressBefore = this.insertionsInProgress.clone();
        try (GraphSearcher gs = this.searchers.get()) {
            NodeArray naturalScratchPooled = this.naturalScratch.get();
            NodeArray concurrentScratchPooled = this.concurrentScratch.get();
            int ep = this.graph.entry();
            SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(vector);
            SearchResult result = gs.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, ep, new ExcludingBits(node));
            NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
            NodeArray concurrent = this.getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled, ssp.scoreFunction());
            this.updateNeighbors(node, natural, concurrent);
            this.maybeUpdateEntryPoint(node);
            this.maybeImproveOlderNode();
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            this.insertionsInProgress.remove(node);
        }
        return this.graph.ramBytesUsedOneNode();
    }

    /**
     * For very low-dimensional data (dimension &le; 3) in graphs over 20,000 nodes, re-runs
     * {@link #improveConnections(int)} on one random live id. It draws up to 3 ids from
     * {@code [0, graph.size())} and gives up if none of them is live.
     */
    private void maybeImproveOlderNode() {
        if (this.dimension > 3 || this.graph.size() <= 20000) {
            return;
        }
        for (int attempt = 0; attempt < 3; ++attempt) {
            int olderNode = ThreadLocalRandom.current().nextInt(this.graph.size());
            if (this.isLive(olderNode)) {
                this.improveConnections(olderNode);
                return;
            }
        }
    }

    /**
     * Offers {@code node} as the initial entry node. Every {@link #updateEntryNodeIn} insertions,
     * it recomputes the entry point.
     */
    private void maybeUpdateEntryPoint(int node) {
        this.graph.maybeSetInitialEntryNode(node);
        if (this.updateEntryNodeIn.decrementAndGet() == 0) {
            this.updateEntryPoint();
        }
    }

    /** Forces the graph's entry node to {@code ep}. */
    @VisibleForTesting
    public void setEntryPoint(int ep) {
        this.graph.updateEntryNode(ep);
    }

    /**
     * Moves the entry point to the approximate medioid and improves its connections. It then resets
     * the countdown to the graph size, or to {@link #DEFAULT_ENTRY_UPDATE_INTERVAL} if no node could
     * be chosen.
     */
    private void updateEntryPoint() {
        int newEntryNode = this.approximateMedioid();
        this.graph.updateEntryNode(newEntryNode);
        if (newEntryNode >= 0) {
            this.improveConnections(newEntryNode);
            this.updateEntryNodeIn.addAndGet(this.graph.size());
        } else {
            this.updateEntryNodeIn.addAndGet(DEFAULT_ENTRY_UPDATE_INTERVAL);
        }
    }

    /**
     * Re-searches the graph for {@code node}'s neighbours. It re-inserts a diverse subset of them and
     * back-links with no overflow (factor 1.0).
     */
    private void improveConnections(int node) {
        NodeArray naturalScratchPooled;
        SearchResult result;
        try (GraphSearcher gs = this.searchers.get()) {
            naturalScratchPooled = this.naturalScratch.get();
            int ep = this.graph.entry();
            SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(node);
            result = gs.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, ep, new ExcludingBits(node));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        NodeArray natural = toScratchCandidates(result.getNodes(), naturalScratchPooled);
        ConcurrentNeighborMap.Neighbors neighbors = this.graph.nodes.insertDiverse(node, natural);
        this.graph.nodes.backlink(neighbors, node, 1.0f);
    }

    /** Marks {@code node} deleted; it is physically removed by {@link #removeDeletedNodes()}. */
    public void markNodeDeleted(int node) {
        this.graph.markDeleted(node);
    }

    /**
     * Physically removes every node currently marked deleted.
     *
     * <p>Each surviving node that pointed at a deleted node is offered that node's surviving
     * neighbours as replacement candidates. If there are none, it is offered up to
     * {@code 2 * maxDegree} random live nodes instead. If the entry node is being deleted, the entry
     * point is recomputed first.
     *
     * @return the number of removed nodes times {@code graph.ramBytesUsedOneNode()}
     */
    public synchronized long removeDeletedNodes() {
        ThreadSafeGrowableBitSet toDelete = this.graph.getDeletedNodes().copy();
        int nRemoved = toDelete.cardinality();
        if (nRemoved == 0) {
            return 0L;
        }
        IntArrayList liveNodes = new IntArrayList();
        for (int id = 0; id < this.graph.getIdUpperBound(); ++id) {
            if (this.graph.containsNode(id) && !toDelete.get(id)) {
                liveNodes.add(id);
            }
        }

        // 1. for each surviving node, collect the surviving neighbours of its deleted neighbours
        ConcurrentHashMap<Integer, Set<Integer>> newEdges = new ConcurrentHashMap<>();
        this.parallelExecutor.submit(() -> IntStream.range(0, this.graph.getIdUpperBound()).parallel().forEach(i -> {
            ConcurrentNeighborMap.Neighbors neighbors = this.graph.getNeighbors(i);
            if (neighbors == null || toDelete.get(i)) {
                return;
            }
            for (NodesIterator it = neighbors.iterator(); it.hasNext(); ) {
                int j = it.nextInt();
                if (!toDelete.get(j)) {
                    continue;
                }
                Set<Integer> newEdgesForI = newEdges.computeIfAbsent(i, key -> ConcurrentHashMap.newKeySet());
                for (NodesIterator jt = this.graph.getNeighbors(j).iterator(); jt.hasNext(); ) {
                    int k = jt.nextInt();
                    if (i != k && !toDelete.get(k)) {
                        newEdgesForI.add(k);
                    }
                }
            }
        })).join();

        // 2. score the candidates and swap them in for the deleted neighbours
        this.simdExecutor.submit(() -> newEdges.entrySet().parallelStream().forEach(e -> {
            int node = e.getKey();
            ScoreFunction sf = this.scoreProvider.searchProviderFor(node).scoreFunction();
            NodeArray candidates = new NodeArray(this.graph.maxDegree);
            for (int k : e.getValue()) {
                candidates.insertSorted(k, sf.similarityTo(k));
            }
            if (candidates.size() == 0) {
                // nothing to inherit: fall back to random live nodes so the node keeps some edges
                ThreadLocalRandom random = ThreadLocalRandom.current();
                for (int attempt = 0; attempt < 2 * this.graph.maxDegree(); ++attempt) {
                    int randomNode = liveNodes.get(random.nextInt(liveNodes.size()));
                    if (randomNode != node && !candidates.contains(randomNode)) {
                        candidates.insertSorted(randomNode, sf.similarityTo(randomNode));
                    }
                    if (candidates.size() == this.graph.maxDegree) {
                        break;
                    }
                }
            }
            this.graph.nodes.replaceDeletedNeighbors(node, toDelete, candidates);
        })).join();

        // 3. move the entry point off a deleted node before the node disappears
        if (toDelete.get(this.graph.entry())) {
            this.updateEntryPoint();
        }

        // 4. drop the deleted nodes themselves (Integer.MAX_VALUE is the bitset's "no more bits")
        assert (toDelete.cardinality() == nRemoved) : "cardinality changed";
        for (int id = toDelete.nextSetBit(0); id != Integer.MAX_VALUE; id = toDelete.nextSetBit(id + 1)) {
            this.graph.removeNode(id);
        }
        return (long) nRemoved * this.graph.ramBytesUsedOneNode();
    }

    /**
     * Returns the node nearest the score provider's approximate centroid.
     *
     * <p>It falls back to a random live node when the centroid is (near) zero or the search returns
     * nothing. It returns -1 for an empty graph.
     */
    private int approximateMedioid() {
        if (this.graph.size() == 0) {
            return -1;
        }
        VectorFloat<?> centroid = this.scoreProvider.approximateCentroid();
        // a zero-length centroid cannot be searched for meaningfully
        if ((double) VectorUtil.dotProduct(centroid, centroid) < 1.0E-6) {
            return this.randomLiveNode();
        }
        int ep = this.graph.entry();
        SearchScoreProvider ssp = this.scoreProvider.searchProviderFor(centroid);
        try (GraphSearcher gs = this.searchers.get()) {
            SearchResult.NodeScore[] found = gs.searchInternal(ssp, this.beamWidth, this.beamWidth, 0.0f, 0.0f, ep, Bits.ALL).getNodes();
            return found.length != 0 ? found[0].node : this.randomLiveNode();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Links {@code nodeId} to a diverse subset of the union of both candidate lists. It then adds
     * back-links, allowing the configured neighbour overflow.
     */
    private void updateNeighbors(int nodeId, NodeArray natural, NodeArray concurrent) {
        NodeArray toMerge;
        if (concurrent.size() == 0) {
            toMerge = natural;
        } else if (natural.size() == 0) {
            toMerge = concurrent;
        } else {
            toMerge = NodeArray.merge(natural, concurrent);
        }
        ConcurrentNeighborMap.Neighbors neighbors = this.graph.nodes.insertDiverse(nodeId, toMerge);
        this.graph.nodes.backlink(neighbors, nodeId, this.neighborOverflow);
    }

    /**
     * Clears {@code scratch} and appends {@code candidates} in the given order. The caller guarantees
     * that order is already sorted.
     */
    private static NodeArray toScratchCandidates(SearchResult.NodeScore[] candidates, NodeArray scratch) {
        scratch.clear();
        for (SearchResult.NodeScore candidate : candidates) {
            scratch.addInOrder(candidate.node, candidate.score);
        }
        return scratch;
    }

    /**
     * Fills {@code scratch} with every id in {@code inProgress} except {@code newNode}, sorted by
     * score against {@code newNode}'s score function.
     */
    private NodeArray getConcurrentCandidates(int newNode, Set<Integer> inProgress, NodeArray scratch, ScoreFunction scoreFunction) {
        scratch.clear();
        for (int n : inProgress) {
            if (n != newNode) {
                scratch.insertSorted(n, scoreFunction.similarityTo(n));
            }
        }
        return scratch;
    }

    /** Closes the per-thread searchers, rethrowing any failure as an {@link IOException}. */
    @Override
    public void close() throws IOException {
        try {
            this.searchers.close();
        } catch (Exception e) {
            ExceptionUtils.throwIoException(e);
        }
    }

    /** Returns true if {@code node} is present in the graph and not marked deleted. */
    private boolean isLive(int node) {
        return this.graph.containsNode(node) && !this.graph.getDeletedNodes().get(node);
    }

    /**
     * Returns a random live node id, or -1 if there is none. It tries 3 cheap random probes and then
     * falls back to a full scan.
     */
    @VisibleForTesting
    int randomLiveNode() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        for (int attempt = 0; attempt < 3; ++attempt) {
            int idUpperBound = this.graph.getIdUpperBound();
            if (idUpperBound == 0) {
                return -1;
            }
            int candidate = random.nextInt(idUpperBound);
            if (this.isLive(candidate)) {
                return candidate;
            }
        }
        ArrayList<Integer> live = new ArrayList<>();
        for (int id = 0; id < this.graph.getIdUpperBound(); ++id) {
            if (this.isLive(id)) {
                live.add(id);
            }
        }
        if (live.isEmpty()) {
            return -1;
        }
        return live.get(random.nextInt(live.size()));
    }

    /** Asserts that nothing is marked deleted and that every edge points at an existing node. */
    @VisibleForTesting
    void validateAllNodesLive() {
        assert (this.graph.getDeletedNodes().cardinality() == 0);
        for (int i = 0; i < this.graph.getIdUpperBound(); ++i) {
            if (!this.graph.containsNode(i)) {
                continue;
            }
            for (NodesIterator it = this.graph.getNeighbors(i).iterator(); it.hasNext(); ) {
                int j = it.nextInt();
                assert (this.graph.containsNode(j)) : String.format("Edge %d -> %d is invalid", i, j);
            }
        }
    }

    /** Returns the value computed by the last {@link #cleanup()}, or NaN if it has not run. */
    public double getAverageShortEdges() {
        return this.averageShortEdges;
    }

    /**
     * Loads a graph from {@code in}, re-scoring every edge with this builder's exact score function.
     *
     * <p>The layout read is as follows:
     * <ul>
     *   <li>a header of {@code size, entryNode, maxDegree};</li>
     *   <li>{@code size} records of {@code nodeId, nNeighbors, neighbor...}.</li>
     * </ul>
     * Neighbours are assumed to be stored in score order, since they are appended with
     * {@code addInOrder}.
     *
     * @throws IllegalStateException if the graph already contains nodes
     */
    public void load(RandomAccessReader in) throws IOException {
        if (this.graph.size() != 0) {
            throw new IllegalStateException("Cannot load into a non-empty graph");
        }
        int size = in.readInt();
        int entryNode = in.readInt();
        // stored max degree: must be consumed to stay aligned, but the builder's own limit is used
        in.readInt();
        for (int i = 0; i < size; ++i) {
            int nodeId = in.readInt();
            int nNeighbors = in.readInt();
            ScoreFunction.ExactScoreFunction sf = this.scoreProvider.searchProviderFor(nodeId).exactScoreFunction();
            NodeArray ca = new NodeArray(nNeighbors);
            for (int j = 0; j < nNeighbors; ++j) {
                int neighbor = in.readInt();
                ca.addInOrder(neighbor, sf.similarityTo(neighbor));
            }
            this.graph.addNode(nodeId, ca);
        }
        this.graph.updateEntryNode(entryNode);
    }

    /** {@link Bits} that accepts every index except a single excluded one. */
    private static final class ExcludingBits implements Bits {
        private final int excluded;

        public ExcludingBits(int excluded) {
            this.excluded = excluded;
        }

        @Override
        public boolean get(int index) {
            return index != this.excluded;
        }
    }
}

