/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBeans;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableCollection;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.Sets;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil;

/**
 * Rule that pushes the conjunctions of a {@link Filter} above a {@link Join}, and of the join
 * condition itself, into the join and/or into its inputs.
 *
 * <p>In addition to plain pushdown, filters on equi-join keys are copied to the other join input
 * (e.g. with {@code a = b}, the filter {@code a > 10} also yields {@code b > 10}). Joins whose
 * condition carries an initial event-time temporal join condition are never matched.
 *
 * @param <C> rule configuration type
 */
public abstract class FlinkFilterJoinRule<C extends FlinkFilterJoinRule.Config>
        extends RelRule<C>
        implements TransformationRule {

    /** Rule matching a {@link Filter} directly on top of a {@link Join}. */
    public static final FlinkFilterIntoJoinRule FILTER_INTO_JOIN =
            FlinkFilterIntoJoinRule.Config.DEFAULT.toRule();

    /** Rule matching a {@link Join} alone; pushes down conjuncts of its own condition. */
    public static final FlinkJoinConditionPushRule JOIN_CONDITION_PUSH =
            FlinkJoinConditionPushRule.Config.DEFAULT.toRule();

    protected FlinkFilterJoinRule(C config) {
        super(config);
    }

    /**
     * Pushes the conjunctions of {@code filter} (if present) and of the join condition as far
     * down as the join type allows, and registers the rewritten plan with {@code call}.
     *
     * @param call rule call the result is registered with
     * @param filter filter on top of {@code join}, or {@code null} when only the join condition
     *     is pushed
     * @param join the join being rewritten
     */
    protected void perform(RelOptRuleCall call, @Nullable Filter filter, Join join) {
        final List<RexNode> joinFilters = RelOptUtil.conjunctions(join.getCondition());
        final ImmutableList<RexNode> origJoinFilters = ImmutableList.copyOf(joinFilters);

        // No filter above and no conjuncts in the join condition: nothing to push.
        if (filter == null && joinFilters.isEmpty()) {
            return;
        }

        final List<RexNode> aboveFilters =
                filter != null ? getConjunctions(filter) : new ArrayList<>();
        final ImmutableList<RexNode> origAboveFilters = ImmutableList.copyOf(aboveFilters);

        // In "smart" mode, let Calcite simplify an outer join type based on the filters above.
        JoinRelType joinType = join.getJoinType();
        if (config.isSmart()
                && !origAboveFilters.isEmpty()
                && join.getJoinType() != JoinRelType.INNER) {
            joinType = RelOptUtil.simplifyJoin(join, origAboveFilters, joinType);
        }

        final List<RexNode> leftFilters = new ArrayList<>();
        final List<RexNode> rightFilters = new ArrayList<>();

        // Push the filters above the join into the join condition and into the inputs that do
        // not generate nulls for this join type.
        boolean filterPushed =
                RelOptUtil.classifyFilters(
                        join,
                        aboveFilters,
                        joinType,
                        true,
                        !joinType.generatesNullsOnLeft(),
                        !joinType.generatesNullsOnRight(),
                        joinFilters,
                        leftFilters,
                        rightFilters);

        // Move join conditions rejected by the configured predicate back above the join.
        validateJoinFilters(aboveFilters, joinFilters, join, joinType);

        // If validation just moved the conditions back where they came from, nothing was pushed.
        if (leftFilters.isEmpty()
                && rightFilters.isEmpty()
                && joinFilters.size() == origJoinFilters.size()
                && aboveFilters.size() == origAboveFilters.size()
                && Sets.newHashSet(joinFilters).equals(Sets.newHashSet(origJoinFilters))) {
            filterPushed = false;
        }

        // Push ON-clause conjuncts into the inputs. Note the swapped flags compared to above: an
        // ON-clause conjunct may only go into an input whose unmatched rows are not preserved.
        // ANTI joins are skipped: restricting either input by the ON condition would change
        // which left rows survive.
        if (joinType != JoinRelType.ANTI
                && RelOptUtil.classifyFilters(
                        join,
                        joinFilters,
                        joinType,
                        false,
                        !joinType.generatesNullsOnRight(),
                        !joinType.generatesNullsOnLeft(),
                        joinFilters,
                        leftFilters,
                        rightFilters)) {
            filterPushed = true;
        }

        // Bail out when nothing was pushed (and the join type is unchanged), or when neither the
        // join nor its inputs would receive any condition.
        if ((!filterPushed && joinType == join.getJoinType())
                || (joinFilters.isEmpty() && leftFilters.isEmpty() && rightFilters.isEmpty())) {
            return;
        }

        final RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        final ImmutableList<RelDataType> fieldTypes =
                ImmutableList.<RelDataType>builder()
                        .addAll(RelOptUtil.getFieldTypeList(join.getLeft().getRowType()))
                        .addAll(RelOptUtil.getFieldTypeList(join.getRight().getRowType()))
                        .build();
        final RexNode joinFilter =
                RexUtil.composeConjunction(
                        rexBuilder, RexUtil.fixUp(rexBuilder, joinFilters, fieldTypes));

        // Copy filters on equi-join keys to the other input. Filters from above the join are
        // considered for INNER/LEFT/RIGHT joins; original ON-clause conjuncts only for INNER.
        pushFiltersToAnotherSide(
                join,
                joinType,
                origAboveFilters,
                joinFilter,
                leftFilters,
                rightFilters,
                Arrays.asList(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT));
        pushFiltersToAnotherSide(
                join,
                joinType,
                origJoinFilters,
                null,
                leftFilters,
                rightFilters,
                Collections.singletonList(JoinRelType.INNER));

        // Trivially-true condition, no input filters and unchanged join type: no-op.
        if (joinFilter.isAlwaysTrue()
                && leftFilters.isEmpty()
                && rightFilters.isEmpty()
                && joinType == join.getJoinType()) {
            return;
        }

        final RelBuilder relBuilder = call.builder();
        final RelNode leftRel = relBuilder.push(join.getLeft()).filter(leftFilters).build();
        final RelNode rightRel = relBuilder.push(join.getRight()).filter(rightFilters).build();

        final Join newJoinRel =
                join.copy(
                        join.getTraitSet(),
                        joinFilter,
                        leftRel,
                        rightRel,
                        joinType,
                        join.isSemiJoinDone());
        call.getPlanner().onCopy(join, newJoinRel);
        // onCopy needs a source node; when only the join condition is pushed there is no filter.
        if (filter != null) {
            if (!leftFilters.isEmpty()) {
                call.getPlanner().onCopy(filter, leftRel);
            }
            if (!rightFilters.isEmpty()) {
                call.getPlanner().onCopy(filter, rightRel);
            }
        }

        // Re-apply the filters that stayed above the join. The row type is converted back to the
        // original join's, since a simplified join type can change field nullability.
        relBuilder.push(newJoinRel);
        relBuilder.convert(join.getRowType(), false);
        relBuilder.filter(
                RexUtil.fixUp(
                        rexBuilder,
                        aboveFilters,
                        RelOptUtil.getFieldTypeList(relBuilder.peek().getRowType())));
        call.transformTo(relBuilder.build());
    }

    /**
     * Splits the filter condition into a mutable list of conjunctions, collapsing expanded
     * {@code IS NOT DISTINCT FROM} expressions via {@link
     * RelOptUtil#collapseExpandedIsNotDistinctFromExpr}.
     */
    private List<RexNode> getConjunctions(Filter filter) {
        final List<RexNode> conjunctions = RelOptUtil.conjunctions(filter.getCondition());
        final RexBuilder rexBuilder = filter.getCluster().getRexBuilder();
        for (int i = 0; i < conjunctions.size(); i++) {
            final RexNode node = conjunctions.get(i);
            if (node instanceof RexCall) {
                conjunctions.set(
                        i,
                        RelOptUtil.collapseExpandedIsNotDistinctFromExpr(
                                (RexCall) node, rexBuilder));
            }
        }
        return conjunctions;
    }

    /**
     * Moves every join condition rejected by the configured {@link Predicate} from {@code
     * joinFilters} to {@code aboveFilters}, unless the join type does not project the right
     * input (e.g. SEMI/ANTI). In that case the right columns are not available above the join,
     * so the condition is kept in the join.
     *
     * @param aboveFilters filters to be applied above the join; appended to
     * @param joinFilters conjuncts of the join condition; rejected entries are removed
     * @param join the join being rewritten
     * @param joinType effective (possibly simplified) join type
     */
    protected void validateJoinFilters(
            List<RexNode> aboveFilters,
            List<RexNode> joinFilters,
            Join join,
            JoinRelType joinType) {
        final Iterator<RexNode> filterIter = joinFilters.iterator();
        while (filterIter.hasNext()) {
            final RexNode exp = filterIter.next();
            if (!config.getPredicate().apply(join, joinType, exp) && joinType.projectsRight()) {
                aboveFilters.add(exp);
                filterIter.remove();
            }
        }
    }

    /**
     * Derives filters for one join input from filters on the equi-join keys of the other input.
     * With the key pair {@code a = b}, a filter {@code p(a)} yields {@code p(b)} and vice versa.
     *
     * <p>A filter is only transferred from side X to side Y when X does not generate nulls for
     * {@code joinType}. A predicate on the null-generating side of an outer join is not implied
     * for the preserved side. For example, with {@code L LEFT JOIN R ON a = b WHERE b IS NULL},
     * deriving {@code a IS NULL} for L would wrongly drop unmatched rows of L.
     *
     * @param joinRel join whose condition supplies the key pairs
     * @param joinType effective (possibly simplified) join type
     * @param filtersToPush candidate filters, referencing the join's combined input fields
     * @param joinFilter the rewritten join condition, used when the original condition has no
     *     equi-join keys; may be {@code null}
     * @param leftFilters receives filters derived for the left input (duplicates skipped)
     * @param rightFilters receives filters derived for the right input (duplicates skipped)
     * @param expectedJoinTypes join types for which any filter is transferred at all
     */
    private void pushFiltersToAnotherSide(
            Join joinRel,
            JoinRelType joinType,
            List<RexNode> filtersToPush,
            @Nullable RexNode joinFilter,
            List<RexNode> leftFilters,
            List<RexNode> rightFilters,
            List<JoinRelType> expectedJoinTypes) {
        if (filtersToPush.isEmpty() || !expectedJoinTypes.contains(joinType)) {
            return;
        }
        // Only copy predicates away from a side whose rows are never null-padded.
        final boolean pushLeftToRight = !joinType.generatesNullsOnLeft();
        final boolean pushRightToLeft = !joinType.generatesNullsOnRight();
        if (!pushLeftToRight && !pushRightToLeft) {
            return;
        }

        JoinInfo joinInfo = joinRel.analyzeCondition();
        if (joinInfo.leftSet().isEmpty()) {
            // The original condition has no equi-join keys; try the rewritten one instead.
            if (joinFilter == null) {
                return;
            }
            joinInfo = JoinInfo.of(joinRel.getLeft(), joinRel.getRight(), joinFilter);
            if (joinInfo.leftSet().isEmpty()) {
                return;
            }
        }

        final RelDataType leftRowType = joinRel.getLeft().getRowType();
        final RelDataType rightRowType = joinRel.getRight().getRowType();
        final int leftFieldCnt = leftRowType.getFieldCount();
        // Right keys shifted into the join's combined field space (left fields come first).
        final ImmutableIntList rightKeysWithOffset =
                ImmutableIntList.copyOf(
                        joinInfo.rightKeys.stream()
                                .map(i -> i + leftFieldCnt)
                                .collect(Collectors.toList()));
        final ImmutableBitSet rightKeyBitsWithOffset = ImmutableBitSet.of(rightKeysWithOffset);

        for (RexNode filter : filtersToPush) {
            if (filter.isAlwaysTrue()) {
                continue;
            }
            final ImmutableBitSet inputBits = RelOptUtil.InputFinder.analyze(filter).build();
            if (joinInfo.leftSet().contains(inputBits)) {
                if (pushLeftToRight) {
                    addIfAbsent(
                            rightFilters,
                            remapFilter(
                                    joinInfo.leftKeys, joinInfo.rightKeys, rightRowType, filter));
                }
            } else if (pushRightToLeft && rightKeyBitsWithOffset.contains(inputBits)) {
                addIfAbsent(
                        leftFilters,
                        remapFilter(rightKeysWithOffset, joinInfo.leftKeys, leftRowType, filter));
            }
        }
    }

    /** Appends {@code node} to {@code filters} unless an equal node is already present. */
    private static void addIfAbsent(List<RexNode> filters, RexNode node) {
        if (!filters.contains(node)) {
            filters.add(node);
        }
    }

    /**
     * Rewrites {@code filter} so that every input reference to {@code oldKeys.get(i)} becomes a
     * reference to field {@code newKeys.get(i)} of {@code newInputType}.
     *
     * @throws TableException if {@code filter} references a field that is not in {@code oldKeys}
     */
    private RexNode remapFilter(
            ImmutableIntList oldKeys,
            ImmutableIntList newKeys,
            final RelDataType newInputType,
            RexNode filter) {
        final HashMap<Integer, Integer> mapping = new HashMap<>();
        for (int i = 0; i < oldKeys.size(); i++) {
            mapping.put(oldKeys.get(i), newKeys.get(i));
        }
        final RexShuttle shuttle =
                new RexShuttle() {
                    @Override
                    public RexNode visitInputRef(RexInputRef inputRef) {
                        final Integer newIndex = mapping.get(inputRef.getIndex());
                        if (newIndex == null) {
                            throw new TableException("should not happen");
                        }
                        return new RexInputRef(
                                newIndex, newInputType.getFieldList().get(newIndex).getType());
                    }
                };
        return filter.accept(shuttle);
    }

    /**
     * Returns whether {@code joinCondition} contains a call to {@code
     * TemporalJoinUtil.INITIAL_TEMPORAL_JOIN_CONDITION()} for which {@code
     * TemporalJoinUtil.isInitialRowTimeTemporalTableJoin} holds.
     */
    protected boolean isEventTimeTemporalJoin(RexNode joinCondition) {
        final RexVisitorImpl<Void> temporalConditionFinder =
                new RexVisitorImpl<Void>(true) {
                    @Override
                    public Void visitCall(RexCall call) {
                        if (call.getOperator()
                                        == TemporalJoinUtil.INITIAL_TEMPORAL_JOIN_CONDITION()
                                && TemporalJoinUtil.isInitialRowTimeTemporalTableJoin(call)) {
                            // Abort the traversal as soon as one match is found.
                            throw new Util.FoundOne(call);
                        }
                        return super.visitCall(call);
                    }
                };
        try {
            joinCondition.accept(temporalConditionFinder);
        } catch (Util.FoundOne found) {
            return true;
        }
        return false;
    }

    /** Rule configuration shared by both rule variants. */
    public interface Config extends RelRule.Config {
        /**
         * Whether {@code perform} may simplify an outer join type using the filters above the
         * join. Defaults to {@code false}.
         */
        @ImmutableBeans.Property
        @ImmutableBeans.BooleanDefault(false)
        boolean isSmart();

        /** Sets {@link #isSmart()}. */
        Config withSmart(boolean smart);

        /**
         * Predicate deciding whether a conjunct may stay in the join condition; rejected
         * conjuncts are moved above the join.
         */
        @ImmutableBeans.Property
        Predicate getPredicate();

        /** Sets {@link #getPredicate()}. */
        Config withPredicate(Predicate predicate);
    }

    /** Decides whether a join condition conjunct may stay inside the join. */
    @FunctionalInterface
    public interface Predicate {
        boolean apply(Join join, JoinRelType joinType, RexNode exp);
    }

    /** Variant matching a {@link Filter} on top of a {@link Join}. */
    public static class FlinkFilterIntoJoinRule
            extends FlinkFilterJoinRule<FlinkFilterIntoJoinRule.Config> {
        protected FlinkFilterIntoJoinRule(Config config) {
            super(config);
        }

        @Override
        public boolean matches(RelOptRuleCall call) {
            final Join join = call.rel(1);
            return !isEventTimeTemporalJoin(join.getCondition()) && super.matches(call);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Filter filter = call.rel(0);
            final Join join = call.rel(1);
            perform(call, filter, join);
        }

        /** Configuration for {@link FlinkFilterIntoJoinRule}. */
        public interface Config extends FlinkFilterJoinRule.Config {
            Config DEFAULT =
                    EMPTY.withOperandSupplier(
                                    b0 ->
                                            b0.operand(Filter.class)
                                                    .oneInput(
                                                            b1 ->
                                                                    b1.operand(Join.class)
                                                                            .anyInputs()))
                            .as(Config.class)
                            .withSmart(true)
                            .withPredicate((join, joinType, exp) -> true)
                            .as(Config.class);

            @Override
            default FlinkFilterIntoJoinRule toRule() {
                return new FlinkFilterIntoJoinRule(this);
            }
        }
    }

    /** Variant matching a {@link Join} alone and pushing down its own condition. */
    public static class FlinkJoinConditionPushRule
            extends FlinkFilterJoinRule<FlinkJoinConditionPushRule.Config> {
        protected FlinkJoinConditionPushRule(Config config) {
            super(config);
        }

        @Override
        public boolean matches(RelOptRuleCall call) {
            final Join join = call.rel(0);
            return !isEventTimeTemporalJoin(join.getCondition()) && super.matches(call);
        }

        @Override
        public void onMatch(RelOptRuleCall call) {
            final Join join = call.rel(0);
            perform(call, null, join);
        }

        /** Configuration for {@link FlinkJoinConditionPushRule}. */
        public interface Config extends FlinkFilterJoinRule.Config {
            Config DEFAULT =
                    EMPTY.withOperandSupplier(b -> b.operand(Join.class).anyInputs())
                            .as(Config.class)
                            .withSmart(true)
                            .withPredicate((join, joinType, exp) -> true)
                            .as(Config.class);

            @Override
            default FlinkJoinConditionPushRule toRule() {
                return new FlinkJoinConditionPushRule(this);
            }
        }
    }
}

