/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.vector.impl.storage;

import com.hazelcast.config.vector.Metric;
import com.hazelcast.internal.serialization.Data;
import com.hazelcast.internal.util.JVMUtil;
import com.hazelcast.nio.ObjectDataOutput;
import com.hazelcast.shaded.io.github.jbellis.jvector.graph.SearchResult;
import com.hazelcast.shaded.io.github.jbellis.jvector.vector.types.VectorFloat;
import com.hazelcast.shaded.org.jctools.maps.NonBlockingHashMapLong;
import com.hazelcast.vector.SearchResults;
import com.hazelcast.vector.impl.DataSearchResult;
import com.hazelcast.vector.impl.SearchResultsImpl;
import com.hazelcast.vector.impl.storage.AbstractVectorIndex;
import com.hazelcast.vector.impl.storage.ConcurrentModificationPolicy;
import com.hazelcast.vector.impl.storage.ReplicationStateHolder;
import com.hazelcast.vector.impl.storage.VectorKeysEntry;
import com.hazelcast.vector.impl.storage.VectorOneKeyEntry;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentSkipListMap;

/**
 * Vector index in which several keys may share a single graph node.
 *
 * <p>Vectors are deduplicated by content: {@code vectorToNodeIdMap} is ordered with
 * {@code Arrays::compare}, so keys that are put with element-wise equal vectors are attached to
 * the same node id, and the vector is added to the graph only once. Three maps are kept in sync:
 * <ul>
 *   <li>{@code nodeIdToKeyMap}: node id -&gt; every key currently referencing that node;</li>
 *   <li>{@code keyToNodeIdMap}: key -&gt; node id;</li>
 *   <li>{@code vectorToNodeIdMap}: vector contents -&gt; node id.</li>
 * </ul>
 * When the last key of a node is removed, the node is marked deleted in the graph builder.
 */
public class VectorIndexMultipleKeys extends AbstractVectorIndex {
    /** Estimated fixed heap cost: three references (presumably the three map fields below). */
    static final long FIXED_HEAP_BYTES_USED = 3L * (long) JVMUtil.REFERENCE_COST_IN_BYTES;
    /**
     * Estimated heap cost of one {@code keyToNodeIdMap} entry: a reference plus an object header and
     * 4 bytes (presumably the boxed {@code Integer} node id).
     */
    static final long HEAP_BYTES_USED_FOR_KEY_TO_NODEID_MAP_ENTRY =
            (long) JVMUtil.REFERENCE_COST_IN_BYTES + (long) JVMUtil.OBJECT_HEADER_SIZE + 4L;
    /**
     * Estimated heap cost of one {@code vectorToNodeIdMap} entry: a reference plus an object header
     * and 4 bytes (presumably the boxed {@code Integer} node id).
     */
    static final long HEAP_BYTES_USED_FOR_VECTOR_TO_NODEID_MAP_ENTRY =
            (long) JVMUtil.REFERENCE_COST_IN_BYTES + (long) JVMUtil.OBJECT_HEADER_SIZE + 4L;
    /** Sum of the per-entry costs of the vector and key maps. */
    static final long HEAP_BYTES_USED_PER_NODE =
            HEAP_BYTES_USED_FOR_VECTOR_TO_NODEID_MAP_ENTRY + HEAP_BYTES_USED_FOR_KEY_TO_NODEID_MAP_ENTRY;
    /** Estimated heap cost charged once per entry of {@code nodeIdToKeyMap}. */
    static final long HEAP_BYTES_PER_ONE_NODE_KEY_MAPPING =
            (long) (8 + JVMUtil.OBJECT_HEADER_SIZE) + 2L * (long) JVMUtil.REFERENCE_COST_IN_BYTES;
    /** Estimated heap cost of every key beyond the first that shares a node. */
    static final long HEAP_BYTES_PER_ADDITIONAL_NODE_KEY_MAPPING = JVMUtil.REFERENCE_COST_IN_BYTES;

    /** Node id -> keys referencing the node; read by {@link #toDataSearchResults}. */
    private final NonBlockingHashMapLong<VectorKeysEntry> nodeIdToKeyMap =
            new NonBlockingHashMapLong<VectorKeysEntry>(1024);
    /**
     * Key -> node id. Not final because {@link #resetState} adopts the map held by the replication
     * state. NOTE(review): unlike the other two maps this is a plain {@code HashMap}; presumably all
     * access to it is confined to one thread — confirm.
     */
    private Map<Data, Integer> keyToNodeIdMap = new HashMap<Data, Integer>(1024);
    /** Vector contents -> node id; {@code Arrays::compare} makes element-wise equal arrays collide. */
    private final Map<float[], Integer> vectorToNodeIdMap =
            new ConcurrentSkipListMap<float[], Integer>(Arrays::compare);

    public VectorIndexMultipleKeys(String indexName, Metric metric, int maxDegree, int efConstruction, int dimensions) {
        super(indexName, metric, maxDegree, efConstruction, dimensions);
    }

    /**
     * Associates {@code key} with {@code vector}.
     *
     * <p>If an element-wise equal vector is already indexed, the key is attached to that existing
     * node and nothing is added to the graph. Otherwise a new node id is taken from
     * {@code idGenerator}, and the vector is registered in {@code vectorToNodeIdMap}, stored in
     * {@code vectorsSupplier} and added to the graph through {@code indexBuilder}.
     *
     * @return the vector of the node the key pointed to before, or {@code null} if the key was not
     *         mapped
     */
    @Override
    protected VectorFloat<?> putInternal(Data key, VectorFloat<?> vector) {
        float[] vectorArray = (float[]) vector.get();
        Integer currentNodeId = this.vectorToNodeIdMap.get(vectorArray);
        if (currentNodeId == null) {
            // The id is reserved up front; updateKeyMapping may replace it with another freshly
            // generated id if it triggers an index recreation, so use the id it reports back.
            UpdateKeyMappingResult updateResult = this.updateKeyMapping(key, this.idGenerator++);
            int nodeId = updateResult.nodeId();
            this.vectorToNodeIdMap.put(vectorArray, nodeId);
            this.vectorsSupplier.add(nodeId, vector);
            this.indexBuilder.addGraphNode(nodeId, vector);
            return updateResult.previousVector();
        }
        // The vector is already indexed: only the key mapping changes. updateKeyMapping hands out a
        // different id only if detaching the key from its previous node makes maybeRecreateIndex()
        // return true; the assert guards against that.
        // NOTE(review): with assertions disabled that case would leave the key mapped to a node id
        // for which no vector was added here — confirm recreation cannot happen on this path.
        UpdateKeyMappingResult updateResult = this.updateKeyMapping(key, currentNodeId);
        assert updateResult.nodeId() == currentNodeId.intValue() : "Should not change node id if the node already exists";
        return updateResult.previousVector();
    }

    /**
     * Expands graph search hits into per-key results, stopping once {@code limit} results have been
     * collected.
     *
     * <p>{@code cmPolicy.action()} is invoked for every visited hit for which no key was emitted:
     * either the node is no longer present in {@code nodeIdToKeyMap}, or its key entry produced no
     * keys.
     */
    @Override
    protected SearchResults<Data, Data> toDataSearchResults(SearchResult searchResult, int limit, ConcurrentModificationPolicy cmPolicy) {
        // Left raw: the element type accepted by SearchResultsImpl's constructor is not visible from
        // here. NOTE(review): parameterize once that signature is confirmed.
        ArrayList resultsList = new ArrayList();
        for (SearchResult.NodeScore nodeScore : searchResult.getNodes()) {
            if (resultsList.size() >= limit) {
                break;
            }
            VectorKeysEntry keyData = this.nodeIdToKeyMap.get(nodeScore.node);
            int addedCount = 0;
            if (keyData != null) {
                addedCount = keyData.forEachKeyWithLimit(limit - resultsList.size(),
                        key -> resultsList.add(new DataSearchResult(nodeScore.node, (Data) key, nodeScore.score)));
            }
            if (addedCount == 0) {
                cmPolicy.action();
            }
        }
        return new SearchResultsImpl<Data, Data>(resultsList);
    }

    /** Returns the node id mapped to {@code key}, or {@code null} if the key is not indexed. */
    @Override
    protected Integer getNodeIdByKey(Data key) {
        return this.keyToNodeIdMap.get(key);
    }

    /** Returns whether at least one node still has keys attached. */
    @Override
    protected boolean hasLiveNodes() {
        return !this.nodeIdToKeyMap.isEmpty();
    }

    /**
     * Removes a node whose last key is gone: drops its vector from {@code vectorToNodeIdMap}, marks
     * it deleted in the graph builder and removes its key entry. The vector is not removed from
     * {@code vectorsSupplier} here.
     */
    private void deleteNode(int nodeId) {
        float[] vectorArray = (float[]) this.vectorsSupplier.getVector(nodeId).get();
        this.vectorToNodeIdMap.remove(vectorArray);
        this.indexBuilder.markNodeDeleted(nodeId);
        this.nodeIdToKeyMap.remove(nodeId);
    }

    /**
     * Removes {@code key} from the index, deleting its node if it was the node's last key, and then
     * lets the base class decide whether to recreate the index.
     *
     * @return {@code true} if the key was mapped, {@code false} otherwise
     */
    @Override
    protected boolean deleteInternal(Data key) {
        Integer nodeId = this.keyToNodeIdMap.remove(key);
        if (nodeId == null) {
            return false;
        }
        this.deleteReferenceToKeyFromNodeId(key, nodeId);
        this.maybeRecreateIndex();
        return true;
    }

    /**
     * Detaches {@code key} from the key set of {@code nodeId} and deletes the node when no keys
     * remain. Does nothing if the node has no key entry.
     */
    private void deleteReferenceToKeyFromNodeId(Data key, int nodeId) {
        VectorKeysEntry indexToKeys = this.nodeIdToKeyMap.get(nodeId);
        if (indexToKeys == null) {
            return;
        }
        indexToKeys.deleteKey(key);
        if (indexToKeys.isEmptyKeys()) {
            this.deleteNode(nodeId);
        }
    }

    /**
     * Points {@code key} at {@code newNodeId} and records the key in that node's key set.
     *
     * <p>If the key already pointed at {@code newNodeId}, nothing else changes. If it pointed at a
     * different node, the key is first detached from that node (deleting the node if it was its last
     * key). If {@code maybeRecreateIndex()} then returns {@code true}, a fresh id taken from
     * {@code idGenerator} is used instead of {@code newNodeId}.
     *
     * @return the node id actually assigned, plus the vector of the previous node or {@code null}
     *         when the key was not mapped before
     */
    private UpdateKeyMappingResult updateKeyMapping(Data key, int newNodeId) {
        Integer previousNodeId = this.keyToNodeIdMap.put(key, newNodeId);
        if (previousNodeId != null && previousNodeId == newNodeId) {
            return new UpdateKeyMappingResult(newNodeId, this.vectorsSupplier.getVector(previousNodeId));
        }
        VectorFloat<?> previousVector = null;
        if (previousNodeId != null) {
            previousVector = this.vectorsSupplier.getVector(previousNodeId);
            this.deleteReferenceToKeyFromNodeId(key, previousNodeId);
            if (this.maybeRecreateIndex()) {
                // After recreation newNodeId may no longer be a safe choice; allocate a new id.
                int smallerNewNodeId = this.idGenerator++;
                Integer prevNodeId = this.keyToNodeIdMap.put(key, smallerNewNodeId);
                assert prevNodeId != null && prevNodeId == newNodeId : "Mapping changed concurrently during put operation";
                newNodeId = smallerNewNodeId;
            }
        }
        this.addKeyMapping(key, newNodeId);
        return new UpdateKeyMappingResult(newNodeId, previousVector);
    }

    /** Adds {@code key} to the key entry of {@code newNodeId}, creating the entry if absent. */
    private void addKeyMapping(Data key, int newNodeId) {
        VectorIndexMultipleKeys.computeNonThreadSafe(this.nodeIdToKeyMap, newNodeId,
                (k, v) -> v == null ? new VectorOneKeyEntry(key) : v.addKey(key));
    }

    /**
     * Non-atomic counterpart of {@code Map.compute} for a primitive-keyed map: reads the current
     * value, applies {@code remappingFunction}, then removes the key if the result is {@code null}
     * or stores the result if it is a different instance. Concurrent writers are not guarded against.
     *
     * @return the new value, or {@code null} if the mapping was removed or never existed
     */
    private static <V> V computeNonThreadSafe(NonBlockingHashMapLong<V> map, long key, LongObjBiFunction<V> remappingFunction) {
        Objects.requireNonNull(remappingFunction);
        V oldValue = map.get(key);
        V newValue = remappingFunction.apply(key, oldValue);
        if (newValue == null) {
            if (oldValue != null) {
                map.remove(key);
            }
        } else if (newValue != oldValue) {
            map.put(key, newValue);
        }
        return newValue;
    }

    /** Estimates heap usage from the base class plus the sizes of the three mapping structures. */
    @Override
    public long heapBytesUsed() {
        long keyCount = this.keyToNodeIdMap.size();
        long nodeCount = this.nodeIdToKeyMap.size();
        long vectorCount = this.vectorToNodeIdMap.size();
        return FIXED_HEAP_BYTES_USED
                + super.heapBytesUsed()
                + vectorCount * HEAP_BYTES_USED_FOR_VECTOR_TO_NODEID_MAP_ENTRY
                + keyCount * HEAP_BYTES_USED_FOR_KEY_TO_NODEID_MAP_ENTRY
                + nodeCount * HEAP_BYTES_PER_ONE_NODE_KEY_MAPPING
                // Every node holds at least one key; the extra keys cost one reference each.
                + (keyCount - nodeCount) * HEAP_BYTES_PER_ADDITIONAL_NODE_KEY_MAPPING;
    }

    /**
     * Writes the node id mapped to {@code key}.
     *
     * @throws NullPointerException if {@code key} is not mapped to any node
     */
    @Override
    void writeKeyToNodeIdMapping(ObjectDataOutput out, Data key) throws IOException {
        Integer nodeId = this.keyToNodeIdMap.get(key);
        // Explicit precondition instead of an anonymous NPE from unboxing a missing mapping.
        Objects.requireNonNull(nodeId, "No node id mapped for the key being written");
        out.writeInt(nodeId);
    }

    /**
     * Rebuilds this index's mappings from replicated state. Expects an empty index (checked only
     * with assertions enabled). The replicated {@code keyToNodeIdMap} is adopted as-is (not copied);
     * {@code nodeIdToKeyMap} is rebuilt from it, and {@code vectorToNodeIdMap} is rebuilt from the
     * vectors held by {@code vectorsSupplier} after {@code super.resetState}.
     */
    @Override
    void resetState(ReplicationStateHolder.CollectionReplicationStateHolder.IndexReplicationStateHolder indexState) {
        super.resetState(indexState);
        assert this.keyToNodeIdMap.isEmpty() : "Migration to non empty index";
        this.keyToNodeIdMap = indexState.keyToNodeIdMap;
        assert this.nodeIdToKeyMap.isEmpty() : "Migration to non empty index";
        this.keyToNodeIdMap.forEach(this::addKeyMapping);
        assert this.vectorToNodeIdMap.isEmpty() : "Migration to non empty index";
        this.vectorsSupplier.entries().forEach(entry -> this.vectorToNodeIdMap.put(
                (float[]) ((VectorFloat<?>) entry.getValue()).get(),
                ((Long) entry.getKey()).intValue()));
    }

    /** Outcome of {@link #updateKeyMapping}: the assigned node id and the key's previous vector. */
    private record UpdateKeyMappingResult(int nodeId, VectorFloat<?> previousVector) {
    }

    /** {@code (long, V) -> V} remapping function used by {@link #computeNonThreadSafe}. */
    @FunctionalInterface
    private interface LongObjBiFunction<V> {
        V apply(long key, V value);
    }
}

