/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.vector.impl.storage;

import com.hazelcast.config.InMemoryFormat;
import com.hazelcast.config.vector.VectorCollectionConfig;
import com.hazelcast.config.vector.VectorIndexConfig;
import com.hazelcast.internal.memory.Measurable;
import com.hazelcast.internal.serialization.Data;
import com.hazelcast.internal.util.JVMUtil;
import com.hazelcast.logging.ILogger;
import com.hazelcast.map.impl.record.Record;
import com.hazelcast.map.impl.record.SimpleRecord;
import com.hazelcast.map.impl.recordstore.Storage;
import com.hazelcast.map.impl.recordstore.StorageImpl;
import com.hazelcast.map.impl.recordstore.expiry.ExpirySystem;
import com.hazelcast.spi.impl.NodeEngine;
import com.hazelcast.vector.SearchOptions;
import com.hazelcast.vector.SearchResult;
import com.hazelcast.vector.SearchResults;
import com.hazelcast.vector.VectorDocument;
import com.hazelcast.vector.VectorValues;
import com.hazelcast.vector.impl.DataVectorDocument;
import com.hazelcast.vector.impl.Hints;
import com.hazelcast.vector.impl.InternalSearchResult;
import com.hazelcast.vector.impl.storage.AbstractVectorIndex;
import com.hazelcast.vector.impl.storage.ConcurrentModificationPolicy;
import com.hazelcast.vector.impl.storage.ReplicationStateHolder;
import com.hazelcast.vector.impl.storage.VectorIndexFactory;
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
 * Partition-local storage for a single vector collection.
 * <p>
 * Document values are kept in a {@link Storage} using {@link InMemoryFormat#BINARY}, and the vectors of each
 * document are kept in one {@link AbstractVectorIndex} per configured vector index (see {@link VectorIndexHolder}).
 * The storage maintains the invariant (checked by assertions) that a key is either present in the record store
 * and in every vector index, or in none of them.
 */
public class VectorCollectionStorage implements Measurable {
    /**
     * Shallow size of this object: object header, 6 reference fields, one {@code int} and one {@code boolean}.
     */
    private static final long FIXED_HEAP_BYTES_USED = (long) JVMUtil.OBJECT_HEADER_SIZE
            + 6L * (long) JVMUtil.REFERENCE_COST_IN_BYTES + 4L + 1L;
    /**
     * Total number of search attempts. All but the last one use {@link ConcurrentModificationPolicy#THROW};
     * the last one uses {@link ConcurrentModificationPolicy#SKIP}.
     */
    private static final int MAX_CONCURRENT_MODIFICATION_RETRIES = 3;
    private final NodeEngine nodeEngine;
    private final ILogger logger;
    private final VectorCollectionConfig config;
    private final String name;
    private final Storage<Data, Record<Data>> recordStore;
    private final int partitionId;
    private final boolean binaryMetadataFormat;
    private final VectorIndexHolder vectorIndexes;

    /**
     * Creates an empty storage with one vector index per index configured in {@code config}.
     *
     * @param nodeEngine  node engine providing logging, serialization and partition information
     * @param name        name of the vector collection
     * @param partitionId id of the partition this storage belongs to
     * @param config      configuration of the vector collection
     */
    public VectorCollectionStorage(NodeEngine nodeEngine, String name, int partitionId, VectorCollectionConfig config) {
        this.nodeEngine = nodeEngine;
        this.logger = nodeEngine.getLogger(VectorCollectionStorage.class);
        this.config = config;
        this.name = name;
        this.partitionId = partitionId;
        this.recordStore = new StorageImpl<Record<Data>>(InMemoryFormat.BINARY, ExpirySystem.NULL, nodeEngine.getSerializationService());
        this.binaryMetadataFormat = true;
        this.vectorIndexes = new VectorIndexHolder(config.getVectorIndexConfigs());
    }

    public String getName() {
        return name;
    }

    public VectorCollectionConfig getConfig() {
        return config;
    }

    public int getPartitionId() {
        return partitionId;
    }

    /**
     * Stores a document under the given key, replacing any existing document.
     *
     * @param keyData      serialized key
     * @param userValue    serialized document value
     * @param vectorValues vectors of the document; must match the configured indexes
     * @return the previous document, or {@code null} if the key was absent
     * @throws IllegalArgumentException      if the vectors do not match the configured indexes
     * @throws UnsupportedOperationException if {@code vectorValues} is {@code null} or of an unsupported type
     */
    public VectorDocument<Data> put(Data keyData, Data userValue, VectorValues vectorValues) {
        checkMutatingOperationAllowed();
        validateVectors(vectorValues);
        Record<Data> oldRecord = recordStore.get(keyData);
        Data oldValue = null;
        if (oldRecord == null) {
            recordStore.put(keyData, createRecord(userValue));
        } else {
            // Capture the old value before the record is updated (updateRecordValue may modify it in place).
            oldValue = oldRecord.getValue();
            recordStore.updateRecordValue(keyData, oldRecord, userValue);
        }
        VectorValues oldVectors = putVectorsReturningPrevious(keyData, vectorValues);
        assert (oldRecord == null) == (oldVectors == null) : "Record store inconsistent with vector index";
        return oldRecord != null ? new DataVectorDocument(oldValue, oldVectors) : null;
    }

    /**
     * Stores a document only if no document is associated with the key yet.
     *
     * @return the existing document (left unchanged), or {@code null} if the new document was stored
     * @throws IllegalArgumentException      if the vectors do not match the configured indexes
     * @throws UnsupportedOperationException if {@code vectorValues} is {@code null} or of an unsupported type
     */
    public VectorDocument<Data> putIfAbsent(Data keyData, Data userValue, VectorValues vectorValues) {
        checkMutatingOperationAllowed();
        validateVectors(vectorValues);
        VectorDocument<Data> existing = get(keyData);
        if (existing != null) {
            return existing;
        }
        recordStore.put(keyData, createRecord(userValue));
        putVectors(keyData, vectorValues);
        return null;
    }

    /**
     * Stores a document under the given key without returning the previous document.
     *
     * @throws IllegalArgumentException      if the vectors do not match the configured indexes
     * @throws UnsupportedOperationException if {@code vectorValues} is {@code null} or of an unsupported type
     */
    public void set(Data keyData, Data userValue, VectorValues vectorValues) {
        checkMutatingOperationAllowed();
        validateVectors(vectorValues);
        recordStore.put(keyData, createRecord(userValue));
        putVectors(keyData, vectorValues);
    }

    private static SimpleRecord<Data> createRecord(Data userValue) {
        return new SimpleRecord<Data>(userValue);
    }

    /**
     * Checks that {@code vectorValues} can be stored in this collection.
     * <p>
     * Every rejection happens here, before any mutation. Callers modify the record store before writing the
     * vectors, so a late failure would leave a record without vectors.
     *
     * @throws IllegalArgumentException      if the number or names of the vectors do not match the indexes
     * @throws UnsupportedOperationException if {@code vectorValues} is {@code null} or of an unsupported type
     */
    private void validateVectors(VectorValues vectorValues) {
        if (vectorValues instanceof VectorValues.SingleVectorValues) {
            if (vectorIndexes.isMultiIndex()) {
                throw new IllegalArgumentException("Collection has " + vectorIndexes.getSize() + " indexes, cannot put a single vector");
            }
        } else if (vectorValues instanceof VectorValues.MultiIndexVectorValues mvv) {
            Map<String, float[]> indexNameToVector = mvv.indexNameToVector();
            if (vectorIndexes.getSize() != indexNameToVector.size()) {
                throw new IllegalArgumentException("Collection has " + vectorIndexes.getSize() + " indexes, cannot put " + indexNameToVector.size() + " vectors");
            }
            List<String> notExistsIndexes = indexNameToVector.keySet().stream()
                    .filter(indexName -> !vectorIndexes.doesIndexExist(indexName))
                    .toList();
            if (!notExistsIndexes.isEmpty()) {
                throw new IllegalArgumentException("Invalid vector names specified, the collection does not contain the requested indexes: " + String.valueOf(notExistsIndexes));
            }
        } else {
            throw new UnsupportedOperationException("Unsupported VectorValues type");
        }
    }

    /** Writes the vectors of a document into the corresponding indexes. */
    private void putVectors(Data keyData, VectorValues vectorValues) {
        if (vectorValues instanceof VectorValues.SingleVectorValues svv) {
            vectorIndexes.getSingleIndex().put(keyData, svv.vector());
        } else if (vectorValues instanceof VectorValues.MultiIndexVectorValues mvv) {
            mvv.indexNameToVector().forEach((indexName, vector) -> vectorIndexes.getIndex(indexName).put(keyData, vector));
        } else {
            throw new UnsupportedOperationException("Unsupported VectorValues type");
        }
    }

    /**
     * Writes the vectors of a document into the corresponding indexes.
     *
     * @return the vectors previously stored for the key, or {@code null} if the key was not indexed
     */
    @Nullable
    private VectorValues putVectorsReturningPrevious(Data keyData, VectorValues vectorValues) {
        if (vectorValues instanceof VectorValues.SingleVectorValues svv) {
            float[] previous = vectorIndexes.getSingleIndex().put(keyData, svv.vector(), true);
            return previous != null ? onlyVector(previous) : null;
        }
        if (vectorValues instanceof VectorValues.MultiIndexVectorValues mvv) {
            Map<String, float[]> previousVectors = new HashMap<>();
            for (Map.Entry<String, float[]> entry : mvv.indexNameToVector().entrySet()) {
                float[] previous = vectorIndexes.getIndex(entry.getKey()).put(keyData, entry.getValue(), true);
                if (previous != null) {
                    previousVectors.put(entry.getKey(), previous);
                }
            }
            assert previousVectors.isEmpty() || previousVectors.size() == vectorIndexes.getSize()
                    : "Inconsistent vector indexes, key was present only in " + previousVectors.size();
            return !previousVectors.isEmpty() ? VectorValues.of(previousVectors) : null;
        }
        throw new UnsupportedOperationException("Unsupported VectorValues type");
    }

    /**
     * Returns the document stored under the key together with its vectors from every index.
     *
     * @return the document, or {@code null} if the key is absent
     */
    public VectorDocument<Data> get(Data key) {
        Data value = getInternal(key);
        if (value == null) {
            return null;
        }
        VectorValues vectorValues;
        if (vectorIndexes.isMultiIndex()) {
            Map<String, float[]> vectors = new HashMap<>(vectorIndexes.getSize());
            vectorIndexes.forEachIndex(index -> vectors.put(index.indexName, index.get(key)));
            vectorValues = VectorValues.of(vectors);
        } else {
            vectorValues = onlyVector(vectorIndexes.getSingleIndex().get(key));
        }
        return new DataVectorDocument(value, vectorValues);
    }

    /** Removes the document and its vectors from every index. Does nothing if the key is absent. */
    public void delete(Data key) {
        checkMutatingOperationAllowed();
        Record<Data> oldRecord = recordStore.get(key);
        if (oldRecord == null) {
            return;
        }
        recordStore.removeRecord(key, oldRecord);
        vectorIndexes.forEachIndex(vectorIndex -> {
            boolean deleted = vectorIndex.delete(key);
            assert deleted : "Inconsistent vector indexes, key was not present in " + vectorIndex.indexName;
        });
    }

    /**
     * Removes the document stored under the key.
     *
     * @return the removed document, or {@code null} if the key was absent
     */
    public VectorDocument<Data> remove(Data key) {
        checkMutatingOperationAllowed();
        VectorDocument<Data> oldValue = get(key);
        if (oldValue != null) {
            delete(key);
        }
        return oldValue;
    }

    /** Runs {@code cleanup()} on the selected index ({@code indexName} may be {@code null} for a single-index collection). */
    public void optimize(String indexName) {
        vectorIndexes.validateAndGetIndex(indexName).cleanup();
    }

    private Data getInternal(Data key) {
        Record<Data> record = recordStore.get(key);
        return record != null ? record.getValue() : null;
    }

    /** Wraps a vector of the single index, attaching the index name when the index is named. */
    private VectorValues onlyVector(float[] value) {
        String indexName = vectorIndexes.getSingleIndex().indexName;
        return indexName != null ? VectorValues.of(indexName, value) : VectorValues.of(value);
    }

    /**
     * Searches the partition for the nearest neighbours of the given vector.
     * <p>
     * The first {@code MAX_CONCURRENT_MODIFICATION_RETRIES - 1} attempts use
     * {@link ConcurrentModificationPolicy#THROW} and are retried on {@link ConcurrentModificationException}.
     * The final attempt uses {@link ConcurrentModificationPolicy#SKIP}, so it does not fail because of
     * concurrent modifications.
     */
    public SearchResults<Data, Data> search(VectorValues vectors, SearchOptions searchOptions) {
        for (int attempt = 0; attempt < MAX_CONCURRENT_MODIFICATION_RETRIES - 1; ++attempt) {
            try {
                return doSearch(vectors, searchOptions, ConcurrentModificationPolicy.THROW);
            } catch (ConcurrentModificationException cme) {
                logger.fine("Search attempt " + attempt + " failed", cme);
            }
        }
        return doSearch(vectors, searchOptions, ConcurrentModificationPolicy.SKIP);
    }

    /**
     * Runs one search attempt against the index selected by {@code vectors}.
     * <p>
     * Optionally fills in the values and vectors of the results. Entries whose value or vector is missing
     * are handed to {@code cmPolicy}.
     */
    private SearchResults<Data, Data> doSearch(VectorValues vectors, SearchOptions searchOptions,
                                               ConcurrentModificationPolicy cmPolicy) {
        int limit = getPartitionLimit(searchOptions);
        AbstractVectorIndex index;
        float[] queryVector;
        if (vectors instanceof VectorValues.SingleVectorValues singleVectorValues) {
            if (vectorIndexes.isMultiIndex()) {
                throw new IllegalArgumentException("Index must be selected for collection with more than 1 index");
            }
            index = vectorIndexes.getSingleIndex();
            queryVector = singleVectorValues.vector();
        } else if (vectors instanceof VectorValues.MultiIndexVectorValues multiIndexVectorValues) {
            Map<String, float[]> indexNameToVector = multiIndexVectorValues.indexNameToVector();
            if (indexNameToVector.size() != 1) {
                throw new UnsupportedOperationException("Cannot search multiple vector indexes");
            }
            Map.Entry<String, float[]> entry = indexNameToVector.entrySet().iterator().next();
            String indexName = entry.getKey();
            index = vectorIndexes.getIndex(indexName);
            if (index == null) {
                throw new IllegalArgumentException("No vector index named '" + indexName + "' is defined");
            }
            queryVector = entry.getValue();
        } else {
            throw new UnsupportedOperationException("Unsupported VectorValues type");
        }
        SearchResults<Data, Data> results = index.search(queryVector, limit, cmPolicy);
        if (searchOptions.isIncludeValue()) {
            for (Iterator<SearchResult<Data, Data>> it = results.results(); it.hasNext(); ) {
                fillValue(it, cmPolicy);
            }
        }
        if (searchOptions.isIncludeVectors()) {
            for (Iterator<SearchResult<Data, Data>> it = results.results(); it.hasNext(); ) {
                fillVector(index, it, cmPolicy);
            }
        }
        return results;
    }

    public void lockIndexMutation(String indexName) {
        vectorIndexes.validateAndGetIndex(indexName).lockIndexMutation();
    }

    public void unlockIndexMutation(String indexName) {
        vectorIndexes.validateAndGetIndex(indexName).unlockIndexMutation();
    }

    private void checkMutatingOperationAllowed() {
        vectorIndexes.forEachIndex(AbstractVectorIndex::checkMutatingOperationAllowed);
    }

    /** Removes all documents: recreates every vector index and clears the record store. */
    public void clear() {
        checkMutatingOperationAllowed();
        vectorIndexes.clear();
        recordStore.clear(false);
    }

    public long size() {
        return recordStore.size();
    }

    /**
     * Returns the number of neighbours to request from this partition.
     * <p>
     * This is the {@link Hints#PARTITION_LIMIT} hint when present, otherwise the search limit.
     *
     * @throws IllegalArgumentException if the hint is negative, larger than the search limit, or too small for
     *                                  all partitions together to produce the requested number of results
     */
    private int getPartitionLimit(SearchOptions searchOptions) {
        int resultLimit = searchOptions.getLimit();
        Integer maybePartitionLimit = Hints.PARTITION_LIMIT.get(searchOptions);
        if (maybePartitionLimit == null) {
            return resultLimit;
        }
        if (maybePartitionLimit < 0) {
            throw new IllegalArgumentException("Partition limit cannot be negative");
        }
        if (maybePartitionLimit > resultLimit) {
            throw new IllegalArgumentException("Number of neighbours requested from partition is greater than requested result size");
        }
        int partitionCount = nodeEngine.getPartitionService().getPartitionCount();
        // Multiply in long: in int arithmetic the product can wrap negative and wrongly reject a valid hint.
        if ((long) maybePartitionLimit * partitionCount < resultLimit) {
            throw new IllegalArgumentException("Number of neighbours requested from partition is not sufficient to generate full requested result");
        }
        return maybePartitionLimit;
    }

    /** Advances {@code it} and fills in the result's value, applying {@code cmPolicy} if the key is gone. */
    private void fillValue(Iterator<? extends SearchResult<Data, Data>> it, ConcurrentModificationPolicy cmPolicy) {
        InternalSearchResult result = (InternalSearchResult) it.next();
        Data value = getInternal((Data) result.getKey());
        if (value == null) {
            cmPolicy.action(it);
        } else {
            result.setValue(value);
        }
    }

    /** Advances {@code it} and fills in the result's vector, applying {@code cmPolicy} if the vector is gone. */
    private static void fillVector(AbstractVectorIndex index, Iterator<? extends SearchResult<?, ?>> it,
                                   ConcurrentModificationPolicy cmPolicy) {
        InternalSearchResult result = (InternalSearchResult) it.next();
        float[] vector = index.vectorsSupplier.vectorValue(result.id());
        if (vector == null) {
            cmPolicy.action(it);
        } else {
            // NOTE(review): the vector is wrapped without the index name, unlike get(), which returns named
            // vectors for named indexes. Confirm this is intended.
            result.setVectors(VectorValues.of(vector));
        }
    }

    VectorIndexHolder getVectorIndexes() {
        return vectorIndexes;
    }

    Storage<Data, Record<Data>> getRecordStore() {
        return recordStore;
    }

    public void prepareForMigration() {
        vectorIndexes.forEachIndex(index -> index.prepareForMigration(this, nodeEngine));
    }

    /**
     * Populates this (expected to be empty) storage from replicated state.
     *
     * @throws IllegalArgumentException if the incoming configuration differs from this collection's
     * @throws IllegalStateException    if the number of received indexes differs from the configured number
     */
    void resetState(ReplicationStateHolder.CollectionReplicationStateHolder state) {
        if (!state.config.equals(config)) {
            throw new IllegalArgumentException("Incoming vector collection partition has different configuration");
        }
        if (state.indexes.size() != vectorIndexes.getSize()) {
            throw new IllegalStateException("Number of configured indexes (" + vectorIndexes.getSize() + ") is not equal to number of received indexes (" + state.indexes.size() + ")");
        }
        assert recordStore.isEmpty() : "Migration to non-empty storage";
        state.entries.forEach((key, value) -> recordStore.put((Data) key, createRecord(value)));
        state.indexes.forEach(indexState -> vectorIndexes.validateAndGetIndex(indexState.getName())
                .resetState((ReplicationStateHolder.CollectionReplicationStateHolder.IndexReplicationStateHolder) indexState));
    }

    @Override
    public long heapBytesUsed() {
        return FIXED_HEAP_BYTES_USED + vectorIndexes.heapBytesUsed() + recordStore.getEntryCostEstimator().getEstimate();
    }

    /**
     * Holds the vector indexes of a collection, keyed by index name.
     * <p>
     * A collection with exactly one index may have a {@code null} index name (see {@code onlyVector}).
     */
    static class VectorIndexHolder implements Measurable {
        /** Shallow size of this object: object header, 3 reference fields and one {@code int}. */
        private static final long FIXED_HEAP_BYTES_USED = (long) JVMUtil.OBJECT_HEADER_SIZE
                + 3L * (long) JVMUtil.REFERENCE_COST_IN_BYTES + 4L;
        private final Map<String, AbstractVectorIndex> vectorIndexMap;
        private final int size;
        private final String singleIndexName;
        private final List<VectorIndexConfig> indexConfigs;

        private VectorIndexHolder(List<VectorIndexConfig> indexConfigs) {
            this.indexConfigs = indexConfigs;
            this.size = indexConfigs.size();
            this.singleIndexName = this.size == 1 ? indexConfigs.get(0).getName() : null;
            this.vectorIndexMap = new HashMap<String, AbstractVectorIndex>();
            this.initIndexMap(indexConfigs);
        }

        /** Creates a fresh index for every config, replacing any index already registered under that name. */
        private void initIndexMap(List<VectorIndexConfig> indexConfigs) {
            for (VectorIndexConfig indexConfig : indexConfigs) {
                vectorIndexMap.put(indexConfig.getName(), VectorIndexFactory.create(indexConfig));
            }
        }

        public int getSize() {
            return size;
        }

        public boolean isMultiIndex() {
            return size > 1;
        }

        public boolean doesIndexExist(String indexName) {
            return vectorIndexMap.containsKey(indexName);
        }

        public void forEachIndex(Consumer<AbstractVectorIndex> action) {
            vectorIndexMap.values().forEach(action);
        }

        /** Returns the index with the given (non-null) name, or {@code null} if there is none. */
        public AbstractVectorIndex getIndex(String indexName) {
            assert indexName != null : "the index name has not been specified.";
            return vectorIndexMap.get(indexName);
        }

        /** Returns the only index; must not be used when several indexes are configured. */
        public AbstractVectorIndex getSingleIndex() {
            assert !isMultiIndex() : "not permitted for use with a multi-index holder.";
            return vectorIndexMap.get(singleIndexName);
        }

        /**
         * Resolves an index by name.
         * <p>
         * For a single-index holder, {@code null} selects the only index. Otherwise the name must match an
         * existing index.
         *
         * @throws IllegalArgumentException if the name is missing (multi-index) or does not match any index
         */
        public AbstractVectorIndex validateAndGetIndex(String indexName) {
            if (isMultiIndex()) {
                if (indexName == null) {
                    throw new IllegalArgumentException("The index name has not been specified.");
                }
                AbstractVectorIndex vectorIndex = vectorIndexMap.get(indexName);
                if (vectorIndex == null) {
                    throw new IllegalArgumentException("No index was found with the name: " + indexName);
                }
                return vectorIndex;
            }
            if (indexName != null && !indexName.equals(singleIndexName)) {
                throw new IllegalArgumentException("No index was found with the name: " + indexName);
            }
            return getSingleIndex();
        }

        /**
         * Replaces every index with a freshly created, empty one.
         * <p>
         * NOTE(review): the previous index instances are dropped without an explicit release. Confirm that
         * AbstractVectorIndex holds no resources that need one.
         */
        public void clear() {
            initIndexMap(indexConfigs);
        }

        @Override
        public long heapBytesUsed() {
            // Widen before multiplying so the product is computed in long arithmetic.
            long heapBytesUsed = FIXED_HEAP_BYTES_USED + (long) vectorIndexMap.size() * JVMUtil.REFERENCE_COST_IN_BYTES;
            for (AbstractVectorIndex index : vectorIndexMap.values()) {
                heapBytesUsed += index.heapBytesUsed();
            }
            return heapBytesUsed;
        }
    }
}

