/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.vector.impl.query;

import com.hazelcast.client.impl.ClientEndpoint;
import com.hazelcast.cluster.Address;
import com.hazelcast.core.HazelcastInstanceNotActiveException;
import com.hazelcast.core.MemberLeftException;
import com.hazelcast.instance.impl.NodeState;
import com.hazelcast.internal.partition.IPartitionService;
import com.hazelcast.internal.serialization.Data;
import com.hazelcast.internal.util.ConcurrencyUtil;
import com.hazelcast.internal.util.ExceptionUtil;
import com.hazelcast.internal.util.JVMUtil;
import com.hazelcast.internal.util.collection.PartitionIdSet;
import com.hazelcast.logging.ILogger;
import com.hazelcast.shaded.io.github.jbellis.jvector.annotations.VisibleForTesting;
import com.hazelcast.shaded.io.github.jbellis.jvector.util.GrowableLongHeap;
import com.hazelcast.spi.exception.TargetNotMemberException;
import com.hazelcast.spi.impl.NodeEngine;
import com.hazelcast.spi.impl.operationservice.Operation;
import com.hazelcast.vector.SearchOptions;
import com.hazelcast.vector.SearchResults;
import com.hazelcast.vector.VectorValues;
import com.hazelcast.vector.impl.ops.SearchMemberOperation;
import com.hazelcast.vector.impl.ops.SearchOperation;
import com.hazelcast.vector.impl.query.QueryResult;
import com.hazelcast.vector.impl.query.Searcher;
import com.hazelcast.vector.impl.storage.AbstractVectorIndex;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.function.Supplier;
import javax.annotation.Nullable;

/**
 * {@link Searcher} that runs a vector search across the cluster in two stages.
 *
 * <ol>
 *   <li>One {@link SearchMemberOperation} is sent to every member listed in
 *       {@link IPartitionService#getMemberPartitionsMap()}, covering the partitions that map
 *       assigns to that member. Each member reports the set of partitions it actually scanned.</li>
 *   <li>Every partition that no member reported as scanned is retried individually with a
 *       partition-routed {@link SearchOperation}.</li>
 * </ol>
 *
 * <p>A failed member invocation is logged and treated as having scanned no partitions, so its
 * partitions fall through to stage 2. A failed retry in stage 2 fails the whole search.
 */
public class TwoStageSearcher implements Searcher {
    /**
     * Presumably the shallow heap footprint of an instance: object header plus the four
     * reference fields declared below.
     */
    public static final long FIXED_HEAP_BYTES_USED = (long)JVMUtil.OBJECT_HEADER_SIZE + 4L * (long)JVMUtil.REFERENCE_COST_IN_BYTES;
    private final IPartitionService partitionService;
    private final OperationInvoker operationInvoker;
    private final ILogger logger;
    /** Returns {@code false} once the node should no longer issue retries. */
    private final Supplier<Boolean> activityStatusSupplier;

    /**
     * Creates a searcher that invokes operations through the node engine's operation service
     * under the {@code "hz:service:vector"} service name.
     *
     * @param engine node engine providing the partition service, operation service, node state and logger
     */
    public TwoStageSearcher(final NodeEngine engine) {
        this.partitionService = engine.getPartitionService();
        this.operationInvoker = new OperationInvoker(){

            @Override
            public <E> CompletableFuture<E> invokeOnTargetAsync(Operation op, Address target) {
                return engine.getOperationService().invokeOnTargetAsync("hz:service:vector", op, target);
            }

            @Override
            public <E> CompletableFuture<E> invokeOnPartitionAsync(Operation op, int partition) {
                return engine.getOperationService().invokeOnPartitionAsync("hz:service:vector", op, partition);
            }
        };
        // Retries are suppressed only once the node has fully reached SHUT_DOWN.
        this.activityStatusSupplier = () -> engine.getNode().getState() != NodeState.SHUT_DOWN;
        // NOTE(review): the logger is registered under AbstractVectorIndex rather than this
        // class; confirm that this logging category is intended.
        this.logger = engine.getLogger(AbstractVectorIndex.class);
    }

    /**
     * Creates a searcher with injected collaborators. Instances built this way always report
     * themselves as active, so stage-2 retries are never suppressed.
     */
    private TwoStageSearcher(IPartitionService partitionService, OperationInvoker operationInvoker, ILogger logger) {
        this.partitionService = partitionService;
        this.operationInvoker = operationInvoker;
        this.logger = logger;
        this.activityStatusSupplier = () -> true;
    }

    /**
     * Searches {@code collectionName} on all members, then retries any partitions that were
     * not reported as scanned.
     *
     * @param collectionName name of the vector collection to search
     * @param vectors        query vectors
     * @param options        search options; {@link SearchOptions#getLimit()} sizes the merged result
     * @param endpoint       client on whose behalf the search runs; its UUID is set as the caller
     *                       UUID of every operation, or {@code null} when there is no client
     * @return future completed with the merged results; completed exceptionally if a stage-2
     *         retry fails or if retries are needed after the node reached {@code SHUT_DOWN}
     */
    @Override
    public CompletableFuture<SearchResults<Data, Data>> search(String collectionName, VectorValues vectors, SearchOptions options, @Nullable ClientEndpoint endpoint) {
        int partitionCount = this.getPartitionCount();
        Map<Address, List<Integer>> partitionsMap = this.partitionService.getMemberPartitionsMap();
        int membersCount = partitionsMap.size();
        UUID callerUuid = TwoStageSearcher.callerUuid(endpoint);
        QueryResult clusterResult = new QueryResult(partitionCount, membersCount + 1, options.getLimit(), GrowableLongHeap::new);
        // Stage 1: one operation per member. Each future resolves to the partitions the member
        // reports as scanned, or to an empty set if the invocation failed. The handlers never
        // rethrow, so a member failure only forces a stage-2 retry.
        CompletableFuture[] memberFutures = (CompletableFuture[])partitionsMap.entrySet().stream().map(entry -> {
            Address member = (Address)entry.getKey();
            PartitionIdSet partitions = new PartitionIdSet(partitionCount, (Collection)entry.getValue());
            CompletableFuture future = this.operationInvoker.invokeOnTargetAsync(new SearchMemberOperation(collectionName, vectors, options, partitions).setCallerUuid(callerUuid), member);
            return future.handleAsync((r, e) -> {
                if (e == null) {
                    // A member's merged results are recorded under the first partition it scanned;
                    // a negative id means it scanned nothing.
                    int firstPartitionId = r.scannedPartitions().firstPartition();
                    if (firstPartitionId >= 0) {
                        clusterResult.addResult(firstPartitionId, r.results());
                    }
                    return r.scannedPartitions();
                }
                if (TwoStageSearcher.isRoutineOperationException(e)) {
                    // Guarded because the message is built by string concatenation.
                    if (this.logger.isFineEnabled()) {
                        this.logger.fine("Member search failed for partitions " + String.valueOf(partitions) + " on member " + String.valueOf(member), (Throwable)e);
                    }
                } else {
                    this.logger.warning("Member search failed for partitions " + String.valueOf(partitions) + " on member " + String.valueOf(member), (Throwable)e);
                }
                return new PartitionIdSet(0);
            }, ConcurrencyUtil.CALLER_RUNS);
        }).toArray(CompletableFuture[]::new);
        // Stage 2 starts after every member future has settled. The merged result is produced
        // only after all retries have completed successfully.
        return ((CompletableFuture)CompletableFuture.allOf(memberFutures)
                .thenComposeAsync(v -> this.retryPartitions(memberFutures, clusterResult, collectionName, vectors, options, callerUuid), ConcurrencyUtil.CALLER_RUNS))
                .thenApplyAsync(v -> clusterResult.complete(), ConcurrencyUtil.CALLER_RUNS);
    }

    private int getPartitionCount() {
        return this.partitionService.getPartitionCount();
    }

    /**
     * Returns the UUID to stamp on outgoing operations as their caller, or {@code null} when
     * the search is not running on behalf of a client endpoint.
     */
    @Nullable
    private static UUID callerUuid(@Nullable ClientEndpoint endpoint) {
        return endpoint != null ? endpoint.getUuid() : null;
    }

    /**
     * Stage 2: retries, one {@link SearchOperation} per partition, every partition that no
     * member reported as scanned in stage 1.
     *
     * <p>Must only be called after every future in {@code memberFutures} has completed, because
     * their values are read with {@code getNow}.
     *
     * @return a stage completed when all retries succeed. It is completed exceptionally if any
     *         retry fails, or with {@link HazelcastInstanceNotActiveException} if retries are
     *         needed while the node is no longer active.
     */
    private CompletionStage<Void> retryPartitions(CompletableFuture<PartitionIdSet>[] memberFutures, QueryResult clusterResult, String collectionName, VectorValues vectors, SearchOptions options, @Nullable UUID callerUuid) {
        PartitionIdSet finishedPartitions = new PartitionIdSet(this.getPartitionCount());
        for (CompletableFuture<PartitionIdSet> f : memberFutures) {
            finishedPartitions.addAll((PartitionIdSet)f.getNow(null));
        }
        if (!finishedPartitions.isMissingPartitions()) {
            return CompletableFuture.completedFuture(null);
        }
        if (!this.activityStatusSupplier.get().booleanValue()) {
            this.logger.fine("Not retrying partitions, instance was shut down");
            return CompletableFuture.failedFuture(new HazelcastInstanceNotActiveException());
        }
        PartitionIdSet missingPartitions = new PartitionIdSet(finishedPartitions);
        missingPartitions.complement();
        this.logger.fine("Retrying partitions %s", missingPartitions);
        CompletableFuture[] retryFutures = (CompletableFuture[])missingPartitions.stream().map(partitionId -> {
            CompletableFuture future = this.operationInvoker.invokeOnPartitionAsync(new SearchOperation(collectionName, vectors, options).setCallerUuid(callerUuid), (int)partitionId);
            return future.handleAsync((r, t) -> {
                if (t == null) {
                    clusterResult.addResult((int)partitionId, (SearchResults<Data, Data>)r);
                    return null;
                }
                if (TwoStageSearcher.isRoutineOperationException(t)) {
                    this.logger.fine("Query failed during partition retry", (Throwable)t);
                } else {
                    this.logger.warning("Query failed during partition retry", (Throwable)t);
                }
                // Unlike stage 1, a retry failure is propagated and fails the whole search.
                throw ExceptionUtil.sneakyThrow(t);
            }, ConcurrencyUtil.CALLER_RUNS);
        }).toArray(CompletableFuture[]::new);
        return CompletableFuture.allOf(retryFutures);
    }

    /**
     * Returns whether {@code e} is an expected consequence of cluster membership or lifecycle
     * changes, so it can be logged at fine level instead of as a warning.
     *
     * <p>A {@link CompletionException} wrapper is unwrapped first. Dependent
     * {@link CompletableFuture} stages deliver their failures wrapped that way, and without
     * unwrapping routine failures would be misclassified.
     */
    private static boolean isRoutineOperationException(Throwable e) {
        Throwable cause = e;
        if (cause instanceof CompletionException && cause.getCause() != null) {
            cause = cause.getCause();
        }
        return cause instanceof HazelcastInstanceNotActiveException
                || cause instanceof MemberLeftException
                || cause instanceof TargetNotMemberException;
    }

    /**
     * Abstraction over the asynchronous operation invocations used by this searcher, so they
     * can be substituted (for example in tests).
     */
    @VisibleForTesting
    static interface OperationInvoker {
        /** Invokes {@code op} on the member at {@code target}. */
        public <E> CompletableFuture<E> invokeOnTargetAsync(Operation var1, Address var2);

        /** Invokes {@code op} on the current owner of partition {@code var2}. */
        public <E> CompletableFuture<E> invokeOnPartitionAsync(Operation var1, int var2);
    }
}

