/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.org.apache.calcite.sql2rel;

import com.hazelcast.com.google.common.collect.ImmutableList;
import com.hazelcast.org.apache.calcite.rel.RelHomogeneousShuttle;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.CorrelationId;
import com.hazelcast.org.apache.calcite.rel.core.Filter;
import com.hazelcast.org.apache.calcite.rel.core.Project;
import com.hazelcast.org.apache.calcite.rel.logical.LogicalCorrelate;
import com.hazelcast.org.apache.calcite.rex.RexBuilder;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexCorrelVariable;
import com.hazelcast.org.apache.calcite.rex.RexFieldAccess;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexLocalRef;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.rex.RexOver;
import com.hazelcast.org.apache.calcite.rex.RexPatternFieldRef;
import com.hazelcast.org.apache.calcite.rex.RexShuttle;
import com.hazelcast.org.apache.calcite.rex.RexSubQuery;
import com.hazelcast.org.apache.calcite.rex.RexTableInputRef;
import com.hazelcast.org.apache.calcite.rex.RexVisitorImpl;
import com.hazelcast.org.apache.calcite.tools.RelBuilder;
import com.hazelcast.org.apache.calcite.tools.RelBuilderFactory;
import com.hazelcast.org.apache.calcite.util.ImmutableBitSet;
import com.hazelcast.org.checkerframework.checker.nullness.qual.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apiguardian.api.API;

/**
 * Shuttle that rewrites {@link LogicalCorrelate} nodes whose right input contains calls
 * depending on the correlation variable.
 *
 * <p>For every such correlate, the expressions found in the {@link Project}s and {@link Filter}s
 * of the right input that depend only on the correlation variable are appended as extra
 * projected columns to the left input. The right input is then rewritten so that each matching
 * call becomes a plain field access on the correlation variable, and a final projection drops
 * the helper columns again so the output row keeps its original fields.
 *
 * <p>Correlates whose collected expressions are all {@link RexFieldAccess} instances (including
 * the case where nothing was collected) are left structurally unchanged.
 */
@API(since="1.27", status=API.Status.EXPERIMENTAL)
public final class CorrelateProjectExtractor extends RelHomogeneousShuttle {
    /** Factory of the {@link RelBuilder} used to assemble each rewritten correlate. */
    private final RelBuilderFactory builderFactory;

    /**
     * Creates an extractor.
     *
     * @param factory factory asked for a fresh {@link RelBuilder} each time a correlate is rewritten
     */
    public CorrelateProjectExtractor(RelBuilderFactory factory) {
        this.builderFactory = factory;
    }

    /**
     * Visits both inputs of {@code correlate} and, when the right input holds non-trivial
     * expressions over the correlation variable, moves their evaluation to the left input.
     *
     * @param correlate the correlate to rewrite
     * @return {@code correlate} itself when nothing changed, a copy with the visited inputs when
     *     only the inputs changed, or a projection over a new correlate otherwise
     * @throws AssertionError if the join type is not one of SEMI, ANTI, LEFT or INNER
     */
    @Override
    public RelNode visit(LogicalCorrelate correlate) {
        // Inputs are rewritten first, so nested correlates are handled before this one.
        RelNode left = correlate.getLeft().accept(this);
        RelNode right = correlate.getRight().accept(this);
        int oldLeft = left.getRowType().getFieldCount();
        CorrelationId corrId = correlate.getCorrelationId();
        Set<RexNode> correlatedInRight = findCorrelationDependentCalls(corrId, right);

        // Nothing to extract when only plain field accesses (or nothing at all) were collected.
        if (correlatedInRight.stream().allMatch(RexFieldAccess.class::isInstance)) {
            return copyWithInputs(correlate, left, right);
        }

        RelBuilder builder = this.builderFactory.create(correlate.getCluster(), null);

        // Re-express every collected expression over the left input and add it as a new column.
        builder.push(left);
        List<RexNode> correlatedOverLeft = new ArrayList<>(correlatedInRight.size());
        for (RexNode expression : correlatedInRight) {
            correlatedOverLeft.add(replaceCorrelationsWithInputRef(expression, builder));
        }
        builder.projectPlus(correlatedOverLeft);

        // Map each original expression to a field access on the correlation variable, typed with
        // the row type now on top of the builder. The code relies on the added columns occupying
        // ordinals oldLeft, oldLeft + 1, ... in the set's iteration order.
        RexBuilder rexBuilder = builder.getRexBuilder();
        RexNode correlVariable = rexBuilder.makeCorrel(builder.peek().getRowType(), corrId);
        Map<RexNode, RexNode> transformMapping = new HashMap<>();
        int ordinal = oldLeft;
        for (RexNode expression : correlatedInRight) {
            transformMapping.put(expression, rexBuilder.makeFieldAccess(correlVariable, ordinal));
            ordinal++;
        }

        // The required columns must be resolved while the left input is still on top of the
        // builder, i.e. before the right input is pushed.
        List<RexNode> requiredFields = builder.fields(
                ImmutableBitSet.range(oldLeft, oldLeft + correlatedOverLeft.size()).asList());
        int newLeft = builder.fields().size();

        right = replaceExpressionsUsingMap(right, transformMapping);
        builder.push(right);
        builder.correlate(correlate.getJoinType(), corrId, requiredFields);

        // Drop the helper columns so the result exposes the same fields as the original correlate.
        builder.project(builder.fields(fieldsToRetain(correlate, oldLeft, newLeft, right)));
        return builder.build();
    }

    /**
     * Returns {@code correlate} when both visited inputs equal its current ones, otherwise a copy
     * of it over the visited inputs with every other property taken from {@code correlate}.
     */
    private static RelNode copyWithInputs(LogicalCorrelate correlate, RelNode left, RelNode right) {
        if (correlate.getLeft().equals(left) && correlate.getRight().equals(right)) {
            return correlate;
        }
        return correlate.copy(
                correlate.getTraitSet(),
                left,
                right,
                correlate.getCorrelationId(),
                correlate.getRequiredColumns(),
                correlate.getJoinType());
    }

    /**
     * Computes the ordinals of the rebuilt correlate's output that are kept by the final projection.
     *
     * @param correlate the original correlate, consulted only for its join type
     * @param oldLeft field count of the left input before the helper columns were added
     * @param newLeft field count of the left input after the helper columns were added
     * @param right the rewritten right input
     * @return the first {@code oldLeft} ordinals for SEMI/ANTI; for LEFT/INNER additionally the
     *     ordinals from {@code newLeft} spanning the right input's field count
     * @throws AssertionError for any other join type
     */
    private static List<Integer> fieldsToRetain(
            LogicalCorrelate correlate, int oldLeft, int newLeft, RelNode right) {
        switch (correlate.getJoinType()) {
            case SEMI:
            case ANTI:
                return ImmutableBitSet.range(0, oldLeft).asList();
            case LEFT:
            case INNER:
                return ImmutableBitSet.builder()
                        .set(0, oldLeft)
                        .set(newLeft, newLeft + right.getRowType().getFieldCount())
                        .build()
                        .asList();
            default:
                throw new AssertionError(correlate.getJoinType());
        }
    }

    /**
     * Collects, in encounter order, the expressions in the {@link Project} and {@link Filter}
     * nodes of {@code plan} that the collector classifies as simple correlated expressions for
     * {@code corrId}. Expressions held by other kinds of nodes are not inspected.
     */
    private static Set<RexNode> findCorrelationDependentCalls(CorrelationId corrId, RelNode plan) {
        final SimpleCorrelationCollector finder = new SimpleCorrelationCollector(corrId);
        plan.accept(new RelHomogeneousShuttle(){

            @Override
            public RelNode visit(RelNode other) {
                if (other instanceof Project || other instanceof Filter) {
                    // Called for its side effect on the collector; the returned node is ignored.
                    other.accept(finder);
                }
                return super.visit(other);
            }
        });
        return finder.correlations;
    }

    /**
     * Applies a {@link CallReplacer} built from {@code mapping} to every node of {@code plan},
     * rewriting the children of a node before the node itself.
     */
    private static RelNode replaceExpressionsUsingMap(RelNode plan, Map<RexNode, RexNode> mapping) {
        final CallReplacer replacer = new CallReplacer(mapping);
        return plan.accept(new RelHomogeneousShuttle(){

            @Override
            public RelNode visit(RelNode other) {
                RelNode withRewrittenChildren = super.visitChildren(other);
                return withRewrittenChildren.accept(replacer);
            }
        });
    }

    /**
     * Tells whether {@link SimpleCorrelationDetector} reports {@code node} as a simple correlated
     * expression for {@code id}; a {@code null} verdict from the detector counts as {@code false}.
     */
    private static boolean isSimpleCorrelatedExpression(RexNode node, CorrelationId id) {
        Boolean verdict = node.accept(new SimpleCorrelationDetector(id));
        return Boolean.TRUE.equals(verdict);
    }

    /**
     * Returns a copy of {@code exp} in which every field access whose reference expression is a
     * {@link RexCorrelVariable} is replaced by {@code b.field(i)}, {@code i} being the index of the
     * accessed field.
     */
    private static RexNode replaceCorrelationsWithInputRef(RexNode exp, final RelBuilder b) {
        return exp.accept(new RexShuttle(){

            @Override
            public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
                if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
                    return b.field(fieldAccess.getField().getIndex());
                }
                return super.visitFieldAccess(fieldAccess);
            }
        });
    }

    /** Shuttle that swaps every call found as a key of the mapping for the mapped expression. */
    private static final class CallReplacer extends RexShuttle {
        private final Map<RexNode, RexNode> mapping;

        CallReplacer(Map<RexNode, RexNode> mapping) {
            this.mapping = mapping;
        }

        // NOTE(review): only calls are looked up. Keys of the mapping that are RexFieldAccess
        // instances are never consulted by this class - confirm that this is intended.
        @Override
        public RexNode visitCall(RexCall oldCall) {
            RexNode replacement = this.mapping.get(oldCall);
            return replacement != null ? replacement : super.visitCall(oldCall);
        }
    }

    /**
     * Visitor deciding whether an expression depends on the given correlation variable and on
     * nothing this class rejects: window calls, sub-queries, input/local/table/pattern references
     * and correlation variables with a different id all yield {@code FALSE}.
     */
    private static final class SimpleCorrelationDetector extends RexVisitorImpl<Boolean> {
        private final CorrelationId corrId;

        private SimpleCorrelationDetector(CorrelationId corrId) {
            super(true);
            this.corrId = corrId;
        }

        @Override
        public Boolean visitOver(RexOver over) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitSubQuery(RexSubQuery subQuery) {
            return Boolean.FALSE;
        }

        /**
         * Returns {@code TRUE} when at least one operand yields a non-null verdict and every
         * non-null verdict is {@code TRUE}; returns {@code FALSE} otherwise.
         */
        @Override
        public Boolean visitCall(RexCall call) {
            boolean sawVerdict = false;
            for (RexNode operand : call.operands) {
                Boolean verdict = operand.accept(this);
                if (verdict == null) {
                    // No opinion on this operand (presumably node kinds not overridden here,
                    // such as literals - confirm against RexVisitorImpl); it does not count.
                    continue;
                }
                if (!verdict) {
                    // One rejected operand settles the answer; this visitor keeps no state, so
                    // the remaining operands need not be visited.
                    return Boolean.FALSE;
                }
                sawVerdict = true;
            }
            return sawVerdict;
        }

        @Override
        public @Nullable Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
            // The verdict of a field access is the verdict of the expression it is applied to.
            return fieldAccess.getReferenceExpr().accept(this);
        }

        @Override
        public Boolean visitInputRef(RexInputRef inputRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) {
            return correlVariable.id.equals(this.corrId);
        }

        @Override
        public Boolean visitTableInputRef(RexTableInputRef ref) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitLocalRef(RexLocalRef localRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitPatternFieldRef(RexPatternFieldRef fieldRef) {
            return Boolean.FALSE;
        }
    }

    /**
     * Shuttle that records, without modifying anything, the outermost calls and field accesses
     * classified as simple correlated expressions for the given correlation id. Sub-expressions
     * of a recorded node are not visited.
     */
    private static final class SimpleCorrelationCollector extends RexShuttle {
        private final CorrelationId correlationId;
        // Insertion-ordered: the caller derives column ordinals from the iteration order.
        private final Set<RexNode> correlations = new LinkedHashSet<>();

        SimpleCorrelationCollector(CorrelationId corrId) {
            this.correlationId = corrId;
        }

        @Override
        public RexNode visitCall(RexCall call) {
            if (CorrelateProjectExtractor.isSimpleCorrelatedExpression(call, this.correlationId)) {
                this.correlations.add(call);
                return call;
            }
            return super.visitCall(call);
        }

        @Override
        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            if (CorrelateProjectExtractor.isSimpleCorrelatedExpression(fieldAccess, this.correlationId)) {
                this.correlations.add(fieldAccess);
                return fieldAccess;
            }
            return super.visitFieldAccess(fieldAccess);
        }
    }
}

