/*
 * Decompiled with CFR 0.152.
 */
package com.hazelcast.jet.sql.impl.opt.physical;

import com.hazelcast.jet.sql.impl.opt.Conventions;
import com.hazelcast.jet.sql.impl.opt.OptUtils;
import com.hazelcast.jet.sql.impl.opt.logical.JoinLogicalRel;
import com.hazelcast.jet.sql.impl.opt.metadata.WatermarkedFields;
import com.hazelcast.jet.sql.impl.opt.physical.ImmutableStreamToStreamJoinPhysicalRule;
import com.hazelcast.jet.sql.impl.opt.physical.MustNotExecutePhysicalRel;
import com.hazelcast.jet.sql.impl.opt.physical.StreamToStreamJoinPhysicalRel;
import com.hazelcast.org.apache.calcite.plan.RelOptRule;
import com.hazelcast.org.apache.calcite.plan.RelOptRuleCall;
import com.hazelcast.org.apache.calcite.plan.RelOptUtil;
import com.hazelcast.org.apache.calcite.plan.RelRule;
import com.hazelcast.org.apache.calcite.rel.RelNode;
import com.hazelcast.org.apache.calcite.rel.core.JoinRelType;
import com.hazelcast.org.apache.calcite.rex.RexCall;
import com.hazelcast.org.apache.calcite.rex.RexInputRef;
import com.hazelcast.org.apache.calcite.rex.RexLiteral;
import com.hazelcast.org.apache.calcite.rex.RexNode;
import com.hazelcast.org.apache.calcite.sql.SqlKind;
import com.hazelcast.org.apache.calcite.sql.type.SqlTypeName;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.immutables.value.Value;

/**
 * Planner rule that matches a logical join whose two inputs both satisfy
 * {@code OptUtils::isUnbounded} and replaces it with a
 * {@link StreamToStreamJoinPhysicalRel}.
 *
 * <p>The rule only accepts INNER, LEFT and RIGHT joins, and it requires that
 * the join condition yields at least one time bound in each direction between
 * a watermarked field of the left input and a watermarked field of the right
 * input. When either requirement is not met, the join is replaced with a
 * {@link MustNotExecutePhysicalRel} carrying an explanatory message.
 */
@Value.Enclosing
public final class StreamToStreamJoinPhysicalRule
        extends RelRule<RelRule.Config> {
    static final RelOptRule INSTANCE = new StreamToStreamJoinPhysicalRule(Config.DEFAULT);

    private StreamToStreamJoinPhysicalRule(Config config) {
        super(config);
    }

    /**
     * Converts the matched logical join into its physical counterpart, or into
     * a failing node when the join cannot be executed as a stream-to-stream join.
     *
     * @param call the rule call; {@code call.rel(0)} is the matched {@link JoinLogicalRel}
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        JoinLogicalRel join = call.rel(0);
        JoinRelType joinType = join.getJoinType();
        if (joinType != JoinRelType.INNER && joinType != JoinRelType.LEFT && joinType != JoinRelType.RIGHT) {
            call.transformTo(this.fail(join, "Stream to stream JOIN supports INNER and LEFT/RIGHT OUTER JOIN types"));
            // Stop here: otherwise the executable join rel below would also be
            // registered as an alternative for the very same node.
            return;
        }

        RelNode left = RelRule.convert(join.getLeft(), join.getLeft().getTraitSet().replace(Conventions.PHYSICAL));
        RelNode right = RelRule.convert(join.getRight(), join.getRight().getTraitSet().replace(Conventions.PHYSICAL));
        WatermarkedFields wmFields = this.watermarkedFields(
                join,
                OptUtils.metadataQuery(left).extractWatermarkedFields(left),
                OptUtils.metadataQuery(right).extractWatermarkedFields(right));

        // field index -> (other field index -> bound), filled from each conjunct of the condition
        HashMap<Integer, Map<Integer, Long>> postponeTimeMap = new HashMap<>();
        for (RexNode conjunction : RelOptUtil.conjunctions(join.getCondition())) {
            tryExtractTimeBound(conjunction, wmFields.getFieldIndexes(), postponeTimeMap);
        }

        // Indexes below leftColumns belong to the left input, the rest to the right input.
        int leftColumns = join.getLeft().getRowType().getFieldCount();
        boolean foundLeft = false;
        boolean foundRight = false;
        for (Map.Entry<Integer, Map<Integer, Long>> outer : postponeTimeMap.entrySet()) {
            boolean outerIsLeft = outer.getKey() < leftColumns;
            Iterator<Map.Entry<Integer, Long>> innerIt = outer.getValue().entrySet().iterator();
            while (innerIt.hasNext()) {
                boolean innerIsLeft = innerIt.next().getKey() < leftColumns;
                if (outerIsLeft == innerIsLeft) {
                    // Both fields come from the same input: not a bound between the two sides.
                    innerIt.remove();
                } else if (outerIsLeft) {
                    foundLeft = true;
                } else {
                    foundRight = true;
                }
            }
        }

        if (!foundLeft || !foundRight) {
            call.transformTo(this.fail(join, "A stream-to-stream join must have a join condition constraining the maximum difference between time values of the joined tables in both directions"));
            // Same as above: do not additionally offer the executable join.
            return;
        }

        call.transformTo(new StreamToStreamJoinPhysicalRel(
                join.getCluster(),
                join.getTraitSet().replace(Conventions.PHYSICAL),
                left,
                right,
                join.getCondition(),
                join.getJoinType(),
                postponeTimeMap));
    }

    /**
     * Tries to interpret {@code condition} as a comparison between two
     * watermarked fields plus/minus constants and, if successful, records the
     * resulting bound(s) in {@code postponeTimeMap}.
     *
     * <p>Only {@code =}, {@code <}, {@code <=}, {@code >} and {@code >=} are
     * considered. Both operands must consist solely of input references,
     * integer or day-interval literals, and {@code +}/{@code -} between them,
     * with exactly one input reference on each side of the accumulated
     * inequality. Anything else leaves the map untouched.
     *
     * <p>When several conditions produce a bound for the same pair of fields,
     * the smaller value is kept.
     *
     * @param condition       a single conjunct of the join condition
     * @param wmFieldIndexes  indexes of the watermarked fields in the join's row
     * @param postponeTimeMap output map: field index -> (other field index -> bound)
     * @throws RuntimeException if the condition is a {@code BETWEEN}
     */
    static void tryExtractTimeBound(RexNode condition, Set<Integer> wmFieldIndexes, Map<Integer, Map<Integer, Long>> postponeTimeMap) {
        boolean isLt;
        boolean isGt;
        switch (condition.getKind()) {
            case EQUALS:
                // An equality bounds the difference in both directions.
                isGt = true;
                isLt = true;
                break;
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                isGt = true;
                isLt = false;
                break;
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                isGt = false;
                isLt = true;
                break;
            case BETWEEN:
                throw new RuntimeException("Unexpected BETWEEN");
            case IS_NOT_DISTINCT_FROM:
            default:
                // Not a comparison this rule extracts a bound from.
                return;
        }

        // Single-element arrays act as in/out parameters for addAddends.
        Integer[] positiveField = new Integer[]{null};
        Integer[] negativeField = new Integer[]{null};
        long[] constantsSum = new long[]{0L};

        List<RexNode> operands = ((RexCall) condition).getOperands();
        boolean parsed = addAddends(operands.get(0), positiveField, negativeField, constantsSum, false)
                && addAddends(operands.get(1), positiveField, negativeField, constantsSum, true);
        if (!parsed) {
            return;
        }
        if (positiveField[0] == null || negativeField[0] == null) {
            return;
        }
        if (!wmFieldIndexes.contains(positiveField[0]) || !wmFieldIndexes.contains(negativeField[0])) {
            return;
        }

        if (isLt) {
            postponeTimeMap.computeIfAbsent(negativeField[0], x -> new HashMap<>())
                    .merge(positiveField[0], constantsSum[0], Long::min);
        }
        if (isGt) {
            postponeTimeMap.computeIfAbsent(positiveField[0], x -> new HashMap<>())
                    .merge(negativeField[0], -constantsSum[0], Long::min);
        }
    }

    /**
     * Recursively walks {@code expr} and accumulates its addends.
     *
     * <p>With {@code inverse == false} a literal is subtracted from
     * {@code constantsSum[0]} and an input reference is stored in
     * {@code negativeField[0]}; with {@code inverse == true} a literal is added
     * and an input reference is stored in {@code positiveField[0]}. For a
     * {@code MINUS} call the second operand is processed with the flag flipped.
     *
     * @param expr          expression to decompose
     * @param positiveField one-element holder for the input reference collected with {@code inverse == true}
     * @param negativeField one-element holder for the input reference collected with {@code inverse == false}
     * @param constantsSum  one-element accumulator of the literal values
     * @param inverse       selects the sign/holder as described above
     * @return {@code false} if the expression contains anything other than
     *         integer/day-interval literals, input references and
     *         {@code +}/{@code -}, if a literal has no {@code Long} value, or
     *         if the target field holder is already occupied; {@code true} otherwise
     */
    private static boolean addAddends(RexNode expr, Integer[] positiveField, Integer[] negativeField, long[] constantsSum, boolean inverse) {
        if (expr instanceof RexLiteral) {
            RexLiteral literal = (RexLiteral) expr;
            SqlTypeName typeName = literal.getType().getSqlTypeName();
            if (!SqlTypeName.DAY_INTERVAL_TYPES.contains(typeName) && !SqlTypeName.INT_TYPES.contains(typeName)) {
                return false;
            }
            Long value = literal.getValueAs(Long.class);
            if (value == null) {
                return false;
            }
            constantsSum[0] += inverse ? value : -value;
            return true;
        }

        if (expr instanceof RexInputRef) {
            Integer[] field = inverse ? positiveField : negativeField;
            if (field[0] != null) {
                // A second reference with the same sign is not supported.
                return false;
            }
            field[0] = ((RexInputRef) expr).getIndex();
            return true;
        }

        SqlKind kind = expr.getKind();
        if (kind == SqlKind.PLUS || kind == SqlKind.MINUS) {
            // Subtraction flips the sign of its right-hand operand.
            boolean secondOperandInverse = (kind == SqlKind.MINUS) != inverse;
            List<RexNode> operands = ((RexCall) expr).getOperands();
            return addAddends(operands.get(0), positiveField, negativeField, constantsSum, inverse)
                    && addAddends(operands.get(1), positiveField, negativeField, constantsSum, secondOperandInverse);
        }

        return false;
    }

    /**
     * Creates a {@link MustNotExecutePhysicalRel} with the cluster and row type
     * of {@code node}, the PHYSICAL convention, and the given message.
     */
    private MustNotExecutePhysicalRel fail(RelNode node, String message) {
        return new MustNotExecutePhysicalRel(
                node.getCluster(),
                node.getTraitSet().replace(Conventions.PHYSICAL),
                node.getRowType(),
                message);
    }

    /**
     * Combines the watermarked fields of both join inputs into one set expressed
     * in the join's field indexes: left indexes are used as-is, right indexes are
     * shifted by the number of fields in the left input.
     */
    private WatermarkedFields watermarkedFields(JoinLogicalRel join, WatermarkedFields leftFields, WatermarkedFields rightFields) {
        int offset = join.getLeft().getRowType().getFieldList().size();
        Set<Integer> shiftedRightProps = rightFields.getFieldIndexes().stream()
                .map(right -> right + offset)
                .collect(Collectors.toSet());
        return leftFields.union(new WatermarkedFields(shiftedRightProps));
    }

    /**
     * Rule configuration. {@link #DEFAULT} matches a {@link JoinLogicalRel} in
     * the LOGICAL convention whose both inputs satisfy
     * {@code OptUtils::isUnbounded}.
     */
    @Value.Immutable
    public interface Config
            extends RelRule.Config {
        Config DEFAULT = ImmutableStreamToStreamJoinPhysicalRule.Config.builder()
                .description(StreamToStreamJoinPhysicalRule.class.getSimpleName())
                .operandSupplier(b0 -> b0.operand(JoinLogicalRel.class)
                        .trait(Conventions.LOGICAL)
                        .inputs(
                                b1 -> b1.operand(RelNode.class).predicate(OptUtils::isUnbounded).anyInputs(),
                                b2 -> b2.operand(RelNode.class).predicate(OptUtils::isUnbounded).anyInputs()))
                .build();

        @Override
        default RelOptRule toRule() {
            return new StreamToStreamJoinPhysicalRule(this);
        }
    }
}

