/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.vector.impl.proxy;

import com.hazelcast.cluster.Address;
import com.hazelcast.config.vector.VectorCollectionConfig;
import com.hazelcast.internal.partition.IPartitionService;
import com.hazelcast.internal.serialization.Data;
import com.hazelcast.internal.serialization.SerializationService;
import com.hazelcast.internal.util.CollectionUtil;
import com.hazelcast.internal.util.ConcurrencyUtil;
import com.hazelcast.internal.util.ExceptionUtil;
import com.hazelcast.internal.util.MutableLong;
import com.hazelcast.internal.util.Preconditions;
import com.hazelcast.internal.util.Timer;
import com.hazelcast.spi.impl.AbstractDistributedObject;
import com.hazelcast.spi.impl.InternalCompletableFuture;
import com.hazelcast.spi.impl.NodeEngine;
import com.hazelcast.spi.impl.operationservice.Operation;
import com.hazelcast.spi.impl.operationservice.OperationFactory;
import com.hazelcast.vector.SearchOptions;
import com.hazelcast.vector.SearchResults;
import com.hazelcast.vector.VectorCollection;
import com.hazelcast.vector.VectorDocument;
import com.hazelcast.vector.VectorValues;
import com.hazelcast.vector.impl.VectorCollectionService;
import com.hazelcast.vector.impl.VectorUtil;
import com.hazelcast.vector.impl.ops.ClearOperationsFactory;
import com.hazelcast.vector.impl.ops.DeleteOperation;
import com.hazelcast.vector.impl.ops.GetOperation;
import com.hazelcast.vector.impl.ops.OptimizeOperationsFactory;
import com.hazelcast.vector.impl.ops.PutAllOperationFactory;
import com.hazelcast.vector.impl.ops.PutIfAbsentOperation;
import com.hazelcast.vector.impl.ops.PutOperation;
import com.hazelcast.vector.impl.ops.RemoveOperation;
import com.hazelcast.vector.impl.ops.SetOperation;
import com.hazelcast.vector.impl.ops.SizeOperationsFactory;
import com.hazelcast.vector.impl.ops.VectorEntries;
import com.hazelcast.vector.impl.stats.LocalVectorCollectionStatsImpl;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import javax.annotation.Nonnull;

/**
 * Member-side proxy for a vector collection.
 *
 * <p>Keys and document values are serialized with the node's {@link SerializationService}, and the
 * resulting operations are dispatched through the operation service to the partition(s) that own
 * them. Every operation records its latency in the collection's {@link LocalVectorCollectionStatsImpl}
 * once the invocation completes.
 *
 * @param <K> key type
 * @param <V> document value type
 */
public class VectorCollectionProxy<K, V>
        extends AbstractDistributedObject<VectorCollectionService>
        implements VectorCollection<K, V> {

    /** Service name used for every invocation issued by this proxy. */
    private static final String SERVICE_NAME = "hz:service:vector";

    private final String name;
    private final IPartitionService partitionService;
    private final SerializationService serializationService;
    // Stored at construction; not read anywhere in this class.
    private final VectorCollectionConfig config;
    private final LocalVectorCollectionStatsImpl statistics;
    // Max entries accumulated per member before a putAll batch is flushed.
    // Values <= 0 (including the default 0) disable batching, see isPutAllUseBatching.
    private int putAllBatchSize;

    /**
     * Creates a proxy for the named collection.
     *
     * @param nodeEngine node engine supplying the partition and serialization services
     * @param service    owning vector collection service; also the source of the statistics object
     * @param name       collection name
     * @param config     collection configuration
     */
    public VectorCollectionProxy(NodeEngine nodeEngine, VectorCollectionService service, String name,
                                 VectorCollectionConfig config) {
        super(nodeEngine, service);
        this.name = name;
        this.partitionService = nodeEngine.getPartitionService();
        this.serializationService = nodeEngine.getSerializationService();
        this.statistics = service.getStatistics(name);
        this.config = config;
    }

    @Override
    public String getName() {
        return name;
    }

    @Override
    public String getServiceName() {
        return SERVICE_NAME;
    }

    /**
     * Asynchronously reads the document stored under {@code key} from its partition owner.
     *
     * @param key key to look up; must not be {@code null}
     * @return a stage completing with the owner's response passed through {@code VectorUtil.deserialize}
     * @throws NullPointerException if {@code key} is {@code null}
     */
    @Override
    public CompletionStage<VectorDocument<V>> getAsync(@Nonnull K key) {
        Objects.requireNonNull(key, "key cannot be null");
        Data keyData = serializationService.toData(key);
        return invokeOnKeyOwnerAsyncAndDeserialize(new GetOperation(name, keyData), keyData,
                LocalVectorCollectionStatsImpl::incrementGetLatencyNanos);
    }

    /**
     * Asynchronously stores {@code value} under {@code key}. The document's value is serialized and
     * sent together with its vectors.
     *
     * @param key   key to write; must not be {@code null}
     * @param value document to store; must not be {@code null}
     * @return a stage completing with the deserialized operation response
     * @throws NullPointerException if {@code key} or {@code value} is {@code null}
     */
    @Override
    public CompletionStage<VectorDocument<V>> putAsync(@Nonnull K key, @Nonnull VectorDocument<V> value) {
        Objects.requireNonNull(key, "key cannot be null");
        Objects.requireNonNull(value, "value cannot be null");
        Data keyData = serializationService.toData(key);
        Data valueData = serializationService.toData(value.getValue());
        return invokeOnKeyOwnerAsyncAndDeserialize(
                new PutOperation(name, keyData, valueData, value.getVectors()), keyData,
                LocalVectorCollectionStatsImpl::incrementPutLatencyNanos);
    }

    /**
     * Asynchronously stores {@code value} under {@code key} without returning a document.
     *
     * @param key   key to write; must not be {@code null}
     * @param value document to store; must not be {@code null}
     * @return a stage completing when the set operation completes
     * @throws NullPointerException if {@code key} or {@code value} is {@code null}
     */
    @Override
    public CompletionStage<Void> setAsync(@Nonnull K key, @Nonnull VectorDocument<V> value) {
        Objects.requireNonNull(key, "key cannot be null");
        Objects.requireNonNull(value, "value cannot be null");
        Data keyData = serializationService.toData(key);
        Data valueData = serializationService.toData(value.getValue());
        return invokeOnKeyOwnerAsyncAndWrap(
                new SetOperation(name, keyData, valueData, value.getVectors()), keyData,
                LocalVectorCollectionStatsImpl::incrementSetLatencyNanos);
    }

    /**
     * Asynchronously stores {@code value} under {@code key} via a put-if-absent operation.
     * Latency is recorded in the same statistic as {@link #putAsync}.
     *
     * @param key   key to write; must not be {@code null}
     * @param value document to store; must not be {@code null}
     * @return a stage completing with the deserialized operation response
     * @throws NullPointerException if {@code key} or {@code value} is {@code null}
     */
    @Override
    public CompletionStage<VectorDocument<V>> putIfAbsentAsync(@Nonnull K key, @Nonnull VectorDocument<V> value) {
        Objects.requireNonNull(key, "key cannot be null");
        Objects.requireNonNull(value, "value cannot be null");
        Data keyData = serializationService.toData(key);
        Data valueData = serializationService.toData(value.getValue());
        return invokeOnKeyOwnerAsyncAndDeserialize(
                new PutIfAbsentOperation(name, keyData, valueData, value.getVectors()), keyData,
                LocalVectorCollectionStatsImpl::incrementPutLatencyNanos);
    }

    /**
     * Asynchronously stores all {@code documents}, grouping them by owning member (and, when batching
     * is enabled, flushing in batches of {@code putAllBatchSize} entries per member).
     *
     * @param documents documents to store; the map must not be {@code null}, nor contain
     *                  {@code null} keys or values
     * @return a stage completing when every member invocation has completed
     * @throws NullPointerException if {@code documents} is {@code null}
     */
    @Override
    public CompletionStage<Void> putAllAsync(Map<? extends K, VectorDocument<V>> documents) {
        Objects.requireNonNull(documents, "Null documents map is not allowed");
        long startTimeNanos = Timer.nanos();
        InternalCompletableFuture<Void> future = new InternalCompletableFuture<>();
        putAllInternalAsync(documents, future);
        return future.thenAcceptAsync(
                v -> statistics.incrementPutAllLatencyNanos(documents.size(), Timer.nanosElapsed(startTimeNanos)),
                ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Asynchronously removes the document stored under {@code key}.
     *
     * @param key key to remove; must not be {@code null}
     * @return a stage completing with the deserialized operation response
     * @throws NullPointerException if {@code key} is {@code null}
     */
    @Override
    public CompletionStage<VectorDocument<V>> removeAsync(@Nonnull K key) {
        Objects.requireNonNull(key, "key cannot be null");
        Data keyData = serializationService.toData(key);
        return invokeOnKeyOwnerAsyncAndDeserialize(new RemoveOperation(name, keyData), keyData,
                LocalVectorCollectionStatsImpl::incrementRemoveLatencyNanos);
    }

    /**
     * Asynchronously deletes the document stored under {@code key} without returning it.
     *
     * @param key key to delete; must not be {@code null}
     * @return a stage completing when the delete operation completes
     * @throws NullPointerException if {@code key} is {@code null}
     */
    @Override
    public CompletionStage<Void> deleteAsync(@Nonnull K key) {
        Objects.requireNonNull(key, "key cannot be null");
        Data keyData = serializationService.toData(key);
        return invokeOnKeyOwnerAsyncAndWrap(new DeleteOperation(name, keyData), keyData,
                LocalVectorCollectionStatsImpl::incrementDeleteLatencyNanos);
    }

    /**
     * Asynchronously runs the optimize operation for {@code indexName} on every partition.
     *
     * @param indexName index to optimize, passed unchanged to the operation factory
     * @return a stage completing when all partitions have responded
     */
    @Override
    public CompletionStage<Void> optimizeAsync(String indexName) {
        long startTimeNanos = Timer.nanos();
        OptimizeOperationsFactory factory = new OptimizeOperationsFactory(name, indexName);
        // NOTE(review): unlike the per-key operations, this passes the start timestamp rather than
        // Timer.nanosElapsed(...); presumably the stats method computes the elapsed time itself — confirm.
        return getOperationService().invokeOnAllPartitionsAsync(SERVICE_NAME, factory)
                .thenAcceptAsync(res -> statistics.incrementOptimizeLatencyNanos(startTimeNanos),
                        ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Asynchronously clears the collection on every partition.
     *
     * @return a stage completing when all partitions have responded
     */
    @Override
    public CompletionStage<Void> clearAsync() {
        long startTimeNanos = Timer.nanos();
        ClearOperationsFactory factory = new ClearOperationsFactory(name);
        // NOTE(review): start timestamp (not elapsed time) is passed, as in optimizeAsync/size — confirm.
        return getOperationService().invokeOnAllPartitionsAsync(SERVICE_NAME, factory)
                .thenAcceptAsync(res -> statistics.incrementClearLatencyNanos(startTimeNanos),
                        ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Returns the total number of entries by invoking a size operation on every partition and
     * summing the per-partition {@code Long} results. This call blocks until all partitions respond.
     *
     * @return the summed size across all partitions
     */
    @Override
    public long size() {
        try {
            long startTimeNanos = Timer.nanos();
            Map<Integer, Object> results =
                    getOperationService().invokeOnAllPartitions(SERVICE_NAME, new SizeOperationsFactory(name));
            long sum = results.values().stream().mapToLong(o -> (Long) o).sum();
            // NOTE(review): start timestamp (not elapsed time) is passed, as in optimizeAsync/clearAsync — confirm.
            statistics.incrementSizeLatencyNanos(startTimeNanos);
            return sum;
        } catch (Exception e) {
            throw ExceptionUtil.rethrow(e);
        }
    }

    /**
     * Asynchronously searches the collection using the searcher the service provides for
     * {@code searchOptions}. Latency is recorded together with the number of merged results before
     * they are deserialized.
     *
     * @param vectors       query vectors; must not be {@code null}
     * @param searchOptions search options; must not be {@code null}
     * @return a stage completing with the deserialized search results
     * @throws NullPointerException if either argument is {@code null}
     */
    @Override
    public CompletionStage<SearchResults<K, V>> searchAsync(VectorValues vectors, SearchOptions searchOptions) {
        Objects.requireNonNull(vectors, "vectors cannot be null");
        Objects.requireNonNull(searchOptions, "searchOptions cannot be null");
        long startTimeNanos = Timer.nanos();
        return getService().getSearcher(name, searchOptions)
                .search(name, vectors, searchOptions)
                .thenApplyAsync(merged -> {
                    statistics.incrementSearchLatencyNanos(merged.size(), Timer.nanosElapsed(startTimeNanos));
                    return VectorUtil.deserialize(merged, serializationService);
                }, ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Sets the per-member batch size used by putAll. Values {@code <= 0} disable batching.
     *
     * @param putAllBatchSize new batch size
     */
    public void setPutAllBatchSize(int putAllBatchSize) {
        this.putAllBatchSize = putAllBatchSize;
    }

    /**
     * Serializes {@code map}, groups the entries by partition and sends one putAll invocation per
     * member. When batching is enabled, a member's pending partitions are additionally flushed each
     * time that member has accumulated a multiple of {@code putAllBatchSize} entries.
     *
     * <p>{@code future} is completed normally once every invocation has completed, or exceptionally
     * with the first failure reported by an invocation. Exceptions raised while preparing the
     * invocations (e.g. a {@code null} key or value) are rethrown to the caller.
     *
     * @param map    entries to store
     * @param future future to complete when all invocations finish
     */
    protected void putAllInternalAsync(Map<? extends K, ? extends VectorDocument<V>> map,
                                       @Nonnull InternalCompletableFuture<Void> future) {
        try {
            int mapSize = map.size();
            if (mapSize == 0) {
                future.complete(null);
                return;
            }
            boolean useBatching = isPutAllUseBatching(mapSize);
            int partitionCount = partitionService.getPartitionCount();
            int initialSize = getPutAllInitialSize(useBatching, mapSize, partitionCount);
            Map<Address, List<Integer>> memberPartitionsMap = partitionService.getMemberPartitionsMap();

            // Batching only: per-partition lookup of the owning member address and of an entry
            // counter shared by all partitions owned by that member.
            MutableLong[] counterPerMember = null;
            Address[] addresses = null;
            if (useBatching) {
                counterPerMember = new MutableLong[partitionCount];
                addresses = new Address[partitionCount];
                for (Map.Entry<Address, List<Integer>> memberEntry : memberPartitionsMap.entrySet()) {
                    MutableLong memberCounter = new MutableLong();
                    Address address = memberEntry.getKey();
                    for (int partitionId : memberEntry.getValue()) {
                        counterPerMember[partitionId] = memberCounter;
                        addresses[partitionId] = address;
                    }
                }
            }

            // Outstanding-invocation count. Without batching there is exactly one invocation per member.
            // With batching it starts at 1 so the future cannot complete normally until the final
            // callback.accept(null, null) below, i.e. until every invocation has been issued.
            AtomicInteger counter = new AtomicInteger(useBatching ? 1 : memberPartitionsMap.size());
            BiConsumer<Void, Throwable> callback = (response, t) -> {
                if (t != null) {
                    future.completeExceptionally(t);
                }
                if (counter.decrementAndGet() == 0 && !future.isDone()) {
                    future.complete(null);
                }
            };

            VectorEntries[] entriesPerPartition = new VectorEntries[partitionCount];
            for (Map.Entry<? extends K, ? extends VectorDocument<V>> entry : map.entrySet()) {
                Preconditions.checkNotNull(entry.getKey(), "key cannot be null");
                Preconditions.checkNotNull(entry.getValue(), "value cannot be null");
                Data keyData = toData(entry.getKey());
                int partitionId = partitionService.getPartitionId(keyData);
                VectorEntries entries = entriesPerPartition[partitionId];
                if (entries == null) {
                    entries = new VectorEntries(initialSize);
                    entriesPerPartition[partitionId] = entries;
                }
                entries.add(keyData, VectorUtil.serialize(entry.getValue(), serializationService));

                if (useBatching && ++counterPerMember[partitionId].value % putAllBatchSize == 0) {
                    // The owning member has accumulated a full batch: flush all of its pending partitions.
                    Address owner = addresses[partitionId];
                    List<Integer> ownerPartitions = memberPartitionsMap.get(owner);
                    counter.incrementAndGet();
                    invokePutAllOperation(owner, ownerPartitions, entriesPerPartition, true)
                            .whenCompleteAsync(callback, ConcurrencyUtil.getDefaultAsyncExecutor());
                }
            }

            // Flush whatever is still pending (everything, when not batching), one invocation per member.
            for (Map.Entry<Address, List<Integer>> memberEntry : memberPartitionsMap.entrySet()) {
                if (useBatching) {
                    counter.incrementAndGet();
                }
                invokePutAllOperation(memberEntry.getKey(), memberEntry.getValue(), entriesPerPartition, useBatching)
                        .whenCompleteAsync(callback, ConcurrencyUtil.getDefaultAsyncExecutor());
            }
            if (useBatching) {
                // Release the initial count taken when the counter was created.
                callback.accept(null, null);
            }
        } catch (Throwable e) {
            throw ExceptionUtil.rethrow(e);
        }
    }

    /**
     * Batching is used only when a positive batch size is configured and the map holds more entries
     * than one batch per cluster member. The product is computed in {@code long} so large batch sizes
     * cannot overflow and wrongly disable batching.
     */
    private boolean isPutAllUseBatching(int mapSize) {
        return putAllBatchSize > 0
                && mapSize > (long) putAllBatchSize * getNodeEngine().getClusterService().getSize();
    }

    /**
     * Initial capacity for each partition's {@link VectorEntries}: 1 for a single entry, the batch size
     * when batching, otherwise a heuristic based on the average number of entries per partition.
     */
    private int getPutAllInitialSize(boolean useBatching, int mapSize, int partitionCount) {
        if (mapSize == 1) {
            return 1;
        }
        if (useBatching) {
            return putAllBatchSize;
        }
        // mapSize >= 2 here, so log10(mapSize) > 0. Float arithmetic is kept as-is so the computed
        // capacity does not change.
        float perPartitionEstimate = 20.0f * (float) mapSize / (float) partitionCount;
        return (int) Math.ceil((double) perPartitionEstimate / Math.log10(mapSize));
    }

    /**
     * Sends the pending entries of {@code memberPartitions} to {@code address} in a single
     * multi-partition invocation and clears those slots in {@code entriesPerPartition}.
     *
     * @return an already-completed future when no entries are pending for the member, otherwise a
     *         future completing (with {@code null}) when the invocation completes
     */
    @Nonnull
    private CompletableFuture<Void> invokePutAllOperation(Address address, List<Integer> memberPartitions,
                                                          VectorEntries[] entriesPerPartition, boolean useBatching) {
        // Keep only the member's partitions that have pending entries.
        int[] partitions = new int[memberPartitions.size()];
        int count = 0;
        for (Integer partitionId : memberPartitions) {
            if (entriesPerPartition[partitionId] != null) {
                partitions[count++] = partitionId;
            }
        }
        if (count == 0) {
            return InternalCompletableFuture.newCompletedFuture(null);
        }
        if (count < partitions.length) {
            partitions = Arrays.copyOf(partitions, count);
        }

        // Take the pending entries and clear their slots so subsequent flushes start empty.
        VectorEntries[] entries = new VectorEntries[count];
        long totalSize = 0L;
        for (int i = 0; i < count; i++) {
            int partitionId = partitions[i];
            VectorEntries partitionEntries = entriesPerPartition[partitionId];
            int batchSize = partitionEntries.size();
            assert !useBatching || putAllBatchSize == 0 || batchSize <= putAllBatchSize;
            entries[i] = partitionEntries;
            totalSize += batchSize;
            entriesPerPartition[partitionId] = null;
        }
        if (totalSize == 0L) {
            return InternalCompletableFuture.newCompletedFuture(null);
        }
        PutAllOperationFactory factory = new PutAllOperationFactory(name, partitions, entries);
        return getOperationService()
                .invokeOnPartitionsAsync(SERVICE_NAME, (OperationFactory) factory,
                        Collections.singletonMap(address, CollectionUtil.asIntegerList(partitions)))
                .thenApplyAsync(v -> null, ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Invokes {@code op} on the owner of {@code keyData}, records latency through {@code statsUpdater}
     * and deserializes the response with {@code VectorUtil.deserialize}.
     */
    private CompletionStage<VectorDocument<V>> invokeOnKeyOwnerAsyncAndDeserialize(
            Operation op, Data keyData, BiConsumer<LocalVectorCollectionStatsImpl, Long> statsUpdater) {
        long startTimeNanos = Timer.nanos();
        return invokeOnKeyOwnerAsync(op, keyData).thenApplyAsync(dataVectorDocument -> {
            recordStats(startTimeNanos, statsUpdater);
            return VectorUtil.deserialize(dataVectorDocument, serializationService);
        }, ConcurrencyUtil.CALLER_RUNS);
    }

    /**
     * Invokes {@code op} on the owner of {@code keyData}, records latency through {@code statsUpdater}
     * and passes the raw response through unchanged.
     */
    private <T> CompletionStage<T> invokeOnKeyOwnerAsyncAndWrap(
            Operation op, Data keyData, BiConsumer<LocalVectorCollectionStatsImpl, Long> statsUpdater) {
        long startTimeNanos = Timer.nanos();
        // Explicit type witness: without it the receiver call infers T = Object and `return v`
        // no longer type-checks against the outer T.
        return this.<T>invokeOnKeyOwnerAsync(op, keyData).thenApplyAsync(v -> {
            recordStats(startTimeNanos, statsUpdater);
            return v;
        }, ConcurrencyUtil.CALLER_RUNS);
    }

    /** Applies {@code statsUpdater} to this collection's statistics with the nanos elapsed since start. */
    private void recordStats(long startTimeNanos, BiConsumer<LocalVectorCollectionStatsImpl, Long> statsUpdater) {
        statsUpdater.accept(statistics, Timer.nanosElapsed(startTimeNanos));
    }

    /** Invokes {@code op} on the partition that owns {@code keyData}. */
    private <T> CompletionStage<T> invokeOnKeyOwnerAsync(Operation op, Data keyData) {
        return getOperationService().invokeOnPartitionAsync(SERVICE_NAME, op, getPartitionId(keyData));
    }
}

