/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.optimize.program;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Exchange;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinInfo;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.table.api.config.OptimizerConfigOptions;
import org.apache.flink.table.planner.plan.nodes.FlinkConventions;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalDynamicFilteringTableSourceScan;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalExchange;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalGroupAggregateBase;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalHashJoin;
import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalSortMergeJoin;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalGlobalRuntimeFilterBuilder;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalLocalRuntimeFilterBuilder;
import org.apache.flink.table.planner.plan.nodes.physical.batch.runtimefilter.BatchPhysicalRuntimeFilter;
import org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext;
import org.apache.flink.table.planner.plan.optimize.program.FlinkOptimizeProgram;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.DefaultRelShuttle;
import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil;
import org.apache.flink.table.planner.plan.utils.JoinUtil;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.util.Preconditions;

public class FlinkRuntimeFilterProgram
implements FlinkOptimizeProgram<BatchOptimizeContext> {
    /**
     * Rewrites the plan rooted at {@code root} so that every {@link Join} in it is offered to
     * {@code tryInjectRuntimeFilter}.
     *
     * <p>Returns {@code root} unchanged when the runtime-filter option is switched off. Otherwise
     * the tree is rebuilt bottom-up: the inputs of a node are rewritten first, then the node is
     * copied onto the rewritten inputs.
     *
     * @param root root of the plan to rewrite
     * @param context optimization context; this method does not read it
     * @return the rewritten plan, or {@code root} itself when runtime filters are disabled
     */
    @Override
    public RelNode optimize(RelNode root, BatchOptimizeContext context) {
        if (!isRuntimeFilterEnabled(root)) {
            return root;
        }

        // The two configured thresholds must not overlap: min probe size has to exceed max build size.
        boolean thresholdsAreOrdered = getMinProbeDataSize(root) > getMaxBuildDataSize(root);
        Preconditions.checkState(
                thresholdsAreOrdered,
                "The min probe data size should be larger than the max build data size.");

        DefaultRelShuttle joinRewriter = new DefaultRelShuttle() {

            @Override
            public RelNode visit(RelNode rel) {
                if (rel instanceof Join) {
                    Join join = (Join) rel;
                    RelNode rewrittenLeft = join.getLeft().accept(this);
                    RelNode rewrittenRight = join.getRight().accept(this);
                    Join rebuiltJoin =
                            join.copy(join.getTraitSet(), Arrays.asList(rewrittenLeft, rewrittenRight));
                    // Only joins are candidates for runtime filter injection.
                    return FlinkRuntimeFilterProgram.tryInjectRuntimeFilter(rebuiltJoin);
                }

                // Any other node: rewrite each input, then copy the node onto the new inputs.
                List<RelNode> rewrittenInputs =
                        rel.getInputs().stream()
                                .map(input -> input.accept(this))
                                .collect(Collectors.toList());
                return rel.copy(rel.getTraitSet(), rewrittenInputs);
            }
        };
        return joinRewriter.visit(root);
    }

    /**
     * Attempts to inject a runtime filter into the probe side of {@code join}.
     *
     * <p>The join is returned unchanged when any of the following holds: its join type is not
     * accepted by {@link #isSuitableJoinType}; it is neither a hash join nor a sort-merge join;
     * neither input is large enough to be a probe side; the chosen probe side is the left input
     * of a LEFT join or the right input of a RIGHT join; or no suitable build side is found.
     *
     * @param join the join whose inputs have already been rewritten
     * @return a copy of the join with a rewritten probe input, or {@code join} itself
     */
    private static Join tryInjectRuntimeFilter(Join join) {
        JoinRelType joinType = join.getJoinType();
        if (!isSuitableJoinType(joinType)) {
            return join;
        }
        boolean supportedImplementation =
                join instanceof BatchPhysicalHashJoin || join instanceof BatchPhysicalSortMergeJoin;
        if (!supportedImplementation) {
            return join;
        }

        RelNode left = join.getLeft();
        RelNode right = join.getRight();

        // The left input is preferred as probe side; the right one is only tried if left fails.
        final boolean leftIsBuild;
        if (canBeProbeSide(left)) {
            leftIsBuild = false;
        } else if (canBeProbeSide(right)) {
            leftIsBuild = true;
        } else {
            return join;
        }

        // Presumably the outer side of an outer join must keep all of its rows, so it is
        // never used as the (filtered) probe side.
        boolean probeIsOuterSide =
                (joinType == JoinRelType.LEFT && !leftIsBuild)
                        || (joinType == JoinRelType.RIGHT && leftIsBuild);
        if (probeIsOuterSide) {
            return join;
        }

        JoinInfo joinInfo = join.analyzeCondition();
        final RelNode buildSide = leftIsBuild ? left : right;
        final RelNode probeSide = leftIsBuild ? right : left;
        final ImmutableIntList buildIndices = leftIsBuild ? joinInfo.leftKeys : joinInfo.rightKeys;
        final ImmutableIntList probeIndices = leftIsBuild ? joinInfo.rightKeys : joinInfo.leftKeys;

        Optional<BuildSideInfo> suitableBuild =
                findSuitableBuildSide(
                        buildSide,
                        buildIndices,
                        (candidate, candidateIndices) ->
                                isSuitableDataSize(candidate, probeSide, candidateIndices, probeIndices));
        if (!suitableBuild.isPresent()) {
            return join;
        }

        RelNode newProbe =
                tryPushDownProbeAndInjectRuntimeFilter(probeSide, probeIndices, suitableBuild.get(), false);
        // Keep the original left/right order of the inputs.
        List<RelNode> newInputs =
                leftIsBuild ? Arrays.asList(buildSide, newProbe) : Arrays.asList(newProbe, buildSide);
        return join.copy(join.getTraitSet(), newInputs);
    }

    /**
     * Builds the runtime-filter operator chain and returns the filter node that wraps
     * {@code probeSide}.
     *
     * <p>The chain is: local builder on {@code buildSide} → exchange with SINGLETON distribution →
     * global builder → exchange with BROADCAST_DISTRIBUTED distribution → runtime filter whose
     * second input is {@code probeSide}.
     *
     * @param buildSide node the filter is built from; its row count estimate must be available
     * @param probeSide node the filter is applied to
     * @param buildIndices key field indices on {@code buildSide}
     * @param probeIndices key field indices on {@code probeSide}
     * @return the new runtime filter node
     */
    private static RelNode createNewProbeWithRuntimeFilter(RelNode buildSide, RelNode probeSide, ImmutableIntList buildIndices, ImmutableIntList probeIndices) {
        Optional<Double> estimatedBuildRows = getEstimatedRowCount(buildSide);
        Preconditions.checkState(estimatedBuildRows.isPresent());
        int buildRowCount = estimatedBuildRows.get().intValue();

        // Upper bound on build rows: configured max build bytes divided by the average row size.
        double maxRowsExact =
                Math.ceil((double) getMaxBuildDataSize(buildSide) / FlinkRelMdUtil.binaryRowAverageSize(buildSide));
        int maxRowCount = (int) maxRowsExact;

        double filterRatio = computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices);

        List<String> allBuildFieldNames = buildSide.getRowType().getFieldNames();
        String[] buildFieldNames =
                buildIndices.stream().map(allBuildFieldNames::get).toArray(String[]::new);

        BatchPhysicalLocalRuntimeFilterBuilder localBuilder =
                new BatchPhysicalLocalRuntimeFilterBuilder(
                        buildSide.getCluster(),
                        buildSide.getTraitSet(),
                        buildSide,
                        buildIndices.toIntArray(),
                        buildFieldNames,
                        buildRowCount,
                        maxRowCount);

        BatchPhysicalExchange toGlobalBuilder =
                createExchange(localBuilder, FlinkRelDistribution.SINGLETON());
        BatchPhysicalGlobalRuntimeFilterBuilder globalBuilder =
                new BatchPhysicalGlobalRuntimeFilterBuilder(
                        localBuilder.getCluster(),
                        localBuilder.getTraitSet(),
                        toGlobalBuilder,
                        buildFieldNames,
                        buildRowCount,
                        maxRowCount);

        BatchPhysicalExchange toFilter =
                createExchange(globalBuilder, FlinkRelDistribution.BROADCAST_DISTRIBUTED());
        return new BatchPhysicalRuntimeFilter(
                probeSide.getCluster(),
                probeSide.getTraitSet(),
                toFilter,
                probeSide,
                probeIndices.toIntArray(),
                filterRatio);
    }

    /**
     * Searches {@code rel} and its inputs for a node that can serve as the build side of a
     * runtime filter.
     *
     * <p>The search descends through {@link Calc}, {@link Join} and
     * {@link BatchPhysicalGroupAggregateBase} nodes, translating {@code buildIndices} into the
     * input's field indices at each step, and stops at the first {@link Exchange}. The input of
     * that exchange is accepted when it is not a runtime filter and {@code buildSideChecker}
     * approves it.
     *
     * @param rel node to start the search from
     * @param buildIndices key field indices, expressed against the output of {@code rel}
     * @param buildSideChecker predicate deciding whether a candidate node with its indices is usable
     * @return the accepted build node with its translated indices, or empty when none qualifies
     */
    private static Optional<BuildSideInfo> findSuitableBuildSide(RelNode rel, ImmutableIntList buildIndices, BiFunction<RelNode, ImmutableIntList, Boolean> buildSideChecker) {
        if (rel instanceof Exchange) {
            // The search ends at an exchange: its input is the only candidate on this path.
            RelNode candidate = ((Exchange) rel).getInput();
            if (!(candidate instanceof BatchPhysicalRuntimeFilter)
                    && buildSideChecker.apply(candidate, buildIndices)) {
                return Optional.of(new BuildSideInfo(candidate, buildIndices));
            }
            return Optional.empty();
        }

        if (rel instanceof BatchPhysicalRuntimeFilter) {
            // A runtime filter is never used as a build side.
            return Optional.empty();
        }

        if (rel instanceof Calc) {
            Calc calc = (Calc) rel;
            RexProgram program = calc.getProgram();
            List<RexNode> projects =
                    program.getProjectList().stream()
                            .map(program::expandLocalRef)
                            .collect(Collectors.toList());
            ImmutableIntList inputIndices = getInputIndices(projects, buildIndices);
            if (inputIndices.isEmpty()) {
                return Optional.empty();
            }
            return findSuitableBuildSide(calc.getInput(), inputIndices, buildSideChecker);
        }

        if (rel instanceof Join) {
            Join join = (Join) rel;
            JoinRelType joinType = join.getJoinType();
            if (!isSuitableJoinType(joinType)) {
                return Optional.empty();
            }

            Tuple2<ImmutableIntList, ImmutableIntList> perSide = getInputIndices(join, buildIndices);
            // For outer joins only the outer side is searched: LEFT drops the right indices,
            // RIGHT drops the left indices.
            ImmutableIntList leftIndices = joinType == JoinRelType.RIGHT ? ImmutableIntList.of() : perSide.f0;
            ImmutableIntList rightIndices = joinType == JoinRelType.LEFT ? ImmutableIntList.of() : perSide.f1;
            if (leftIndices.isEmpty() && rightIndices.isEmpty()) {
                return Optional.empty();
            }

            // The left input goes first only when it is directly an exchange; otherwise the
            // right input is tried first.
            boolean leftFirst = !leftIndices.isEmpty() && join.getLeft() instanceof Exchange;
            Optional<BuildSideInfo> found = Optional.empty();
            if (leftFirst) {
                found = findSuitableBuildSide(join.getLeft(), leftIndices, buildSideChecker);
                if (!found.isPresent() && !rightIndices.isEmpty()) {
                    found = findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker);
                }
            } else {
                if (!rightIndices.isEmpty()) {
                    found = findSuitableBuildSide(join.getRight(), rightIndices, buildSideChecker);
                }
                // Fall back to the left input whenever the right one yielded nothing, including
                // when the right side had no usable indices at all. Previously the left input
                // was skipped in that case, so e.g. a LEFT join whose left input is not
                // directly an exchange could never provide a build side.
                if (!found.isPresent() && !leftIndices.isEmpty()) {
                    found = findSuitableBuildSide(join.getLeft(), leftIndices, buildSideChecker);
                }
            }
            return found;
        }

        if (rel instanceof BatchPhysicalGroupAggregateBase) {
            BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase) rel;
            int[] grouping = agg.grouping();
            // An index at or beyond grouping.length cannot be mapped through the grouping array.
            boolean hasNonGroupingKey = buildIndices.stream().anyMatch(index -> index >= grouping.length);
            if (hasNonGroupingKey) {
                return Optional.empty();
            }
            ImmutableIntList inputIndices =
                    ImmutableIntList.copyOf(
                            buildIndices.stream().map(index -> grouping[index]).collect(Collectors.toList()));
            return findSuitableBuildSide(agg.getInput(), inputIndices, buildSideChecker);
        }

        return Optional.empty();
    }

    /**
     * Pushes the runtime filter as far down the probe side as the key indices can be tracked and
     * injects it at the lowest reachable node.
     *
     * <p>Descends through {@link Exchange}, {@link Calc}, {@link Join},
     * {@link BatchPhysicalGroupAggregateBase} and {@link Union}. When a node cannot be passed
     * (or is of another kind), the filter is created right above it if {@code filterHasBenefit}
     * is set; otherwise the node is returned as is.
     *
     * @param rel current probe-side node
     * @param probeIndices key field indices, expressed against the output of {@code rel}
     * @param buildSideInfo build node and build key indices for the filter
     * @param filterHasBenefit whether injecting the filter at this position is considered
     *     worthwhile; it starts as {@code false} in {@code tryInjectRuntimeFilter} and the
     *     recursion sets it to {@code true} after passing an exchange, join or aggregate
     * @return the rewritten probe-side node
     */
    private static RelNode tryPushDownProbeAndInjectRuntimeFilter(RelNode rel, ImmutableIntList probeIndices, BuildSideInfo buildSideInfo, boolean filterHasBenefit) {
        if (rel instanceof BatchPhysicalRuntimeFilter) {
            // An existing runtime filter is left untouched.
            return rel;
        } else if (rel instanceof Exchange) {
            Exchange exchange = (Exchange) rel;
            RelNode pushedInput =
                    tryPushDownProbeAndInjectRuntimeFilter(exchange.getInput(), probeIndices, buildSideInfo, true);
            return exchange.copy(exchange.getTraitSet(), Collections.singletonList(pushedInput));
        } else if (rel instanceof Calc) {
            Calc calc = (Calc) rel;
            RexProgram program = calc.getProgram();
            List<RexNode> projects =
                    program.getProjectList().stream()
                            .map(program::expandLocalRef)
                            .collect(Collectors.toList());
            ImmutableIntList inputIndices = getInputIndices(projects, probeIndices);
            if (!inputIndices.isEmpty()) {
                RelNode pushedInput =
                        tryPushDownProbeAndInjectRuntimeFilter(calc.getInput(), inputIndices, buildSideInfo, filterHasBenefit);
                return calc.copy(calc.getTraitSet(), Collections.singletonList(pushedInput));
            }
        } else if (rel instanceof Join) {
            Join join = (Join) rel;
            Tuple2<ImmutableIntList, ImmutableIntList> perSide = getInputIndices(join, probeIndices);
            ImmutableIntList leftIndices = perSide.f0;
            ImmutableIntList rightIndices = perSide.f1;
            boolean leftReachable = !leftIndices.isEmpty();
            boolean rightReachable = !rightIndices.isEmpty();
            if (leftReachable || rightReachable) {
                // The filter is pushed into every side the keys can be mapped to.
                RelNode newLeft = join.getLeft();
                if (leftReachable) {
                    newLeft = tryPushDownProbeAndInjectRuntimeFilter(newLeft, leftIndices, buildSideInfo, true);
                }
                RelNode newRight = join.getRight();
                if (rightReachable) {
                    newRight = tryPushDownProbeAndInjectRuntimeFilter(newRight, rightIndices, buildSideInfo, true);
                }
                return join.copy(join.getTraitSet(), Arrays.asList(newLeft, newRight));
            }
        } else if (rel instanceof BatchPhysicalGroupAggregateBase) {
            BatchPhysicalGroupAggregateBase agg = (BatchPhysicalGroupAggregateBase) rel;
            int[] grouping = agg.grouping();
            boolean onlyGroupingKeys = probeIndices.stream().allMatch(index -> index < grouping.length);
            if (onlyGroupingKeys) {
                ImmutableIntList inputIndices =
                        ImmutableIntList.copyOf(
                                probeIndices.stream().map(index -> agg.grouping()[index]).collect(Collectors.toList()));
                RelNode pushedInput =
                        tryPushDownProbeAndInjectRuntimeFilter(agg.getInput(), inputIndices, buildSideInfo, true);
                return agg.copy(agg.getTraitSet(), Collections.singletonList(pushedInput));
            }
        } else if (rel instanceof Union) {
            Union union = (Union) rel;
            // Every union input gets the same indices and the same benefit flag.
            List<RelNode> pushedInputs = new ArrayList<RelNode>();
            union.getInputs().forEach(
                    input ->
                            pushedInputs.add(
                                    tryPushDownProbeAndInjectRuntimeFilter(input, probeIndices, buildSideInfo, filterHasBenefit)));
            return union.copy(union.getTraitSet(), pushedInputs, union.all);
        } else if (rel instanceof BatchPhysicalDynamicFilteringTableSourceScan) {
            BatchPhysicalDynamicFilteringTableSourceScan scan = (BatchPhysicalDynamicFilteringTableSourceScan) rel;
            HashSet<Integer> dynamicFilteringIndices = new HashSet<Integer>(scan.dynamicFilteringIndices());
            if (dynamicFilteringIndices.containsAll(probeIndices)) {
                // All probe keys are already dynamic-filtering fields of this scan: no filter added.
                return rel;
            }
        }

        // Could not push any further: inject here, but only when it is considered beneficial.
        if (!filterHasBenefit) {
            return rel;
        }
        return createNewProbeWithRuntimeFilter(
                ignoreExchange(buildSideInfo.buildSide),
                ignoreExchange(rel),
                buildSideInfo.buildIndices,
                probeIndices);
    }

    /**
     * Wraps {@code input} in a {@link BatchPhysicalExchange} with the given distribution.
     *
     * <p>The exchange's trait set starts from the planner's empty trait set and then has the
     * batch physical convention and {@code newDistribution} applied.
     *
     * @param input node to place below the exchange
     * @param newDistribution distribution the exchange should produce
     * @return the new exchange node
     */
    private static BatchPhysicalExchange createExchange(RelNode input, FlinkRelDistribution newDistribution) {
        RelTraitSet physicalTraits =
                input.getCluster().getPlanner().emptyTraitSet().replace(FlinkConventions.BATCH_PHYSICAL());
        RelTraitSet exchangeTraits = physicalTraits.replace(newDistribution);
        return new BatchPhysicalExchange(input.getCluster(), exchangeTraits, input, newDistribution);
    }

    /**
     * Translates output field indices of a projection into the input field indices they refer to.
     *
     * @param projects projection expressions, indexed by output field
     * @param outputIndices output field indices to translate
     * @return the input indices in the same order, or an empty list as soon as one of the
     *     projections is not a plain {@link RexInputRef}
     */
    private static ImmutableIntList getInputIndices(List<RexNode> projects, ImmutableIntList outputIndices) {
        List<Integer> resolved = new ArrayList<>();
        for (int outputIndex : outputIndices) {
            RexNode projection = projects.get(outputIndex);
            if (projection instanceof RexInputRef) {
                resolved.add(((RexInputRef) projection).getIndex());
            } else {
                // A computed expression cannot be traced back to a single input field.
                return ImmutableIntList.of();
            }
        }
        return ImmutableIntList.copyOf(resolved);
    }

    /**
     * Translates output field indices of a join into field indices of its left and right inputs.
     *
     * <p>An output index below the left field count belongs to the left input; any other index
     * belongs to the right input (after subtracting the left field count). If a field is one of
     * the join's equi-keys, the paired key field on the other side is added to the other side's
     * list as well.
     *
     * @param join the join whose output the indices refer to
     * @param outputIndices output field indices to translate
     * @return a pair (left indices, right indices); each side is the empty list unless every
     *     output index could be mapped to that side
     */
    private static Tuple2<ImmutableIntList, ImmutableIntList> getInputIndices(Join join, ImmutableIntList outputIndices) {
        JoinInfo joinInfo = join.analyzeCondition();
        Map<Integer, Integer> leftToRightJoinKeysMapping = createKeysMapping(joinInfo.leftKeys, joinInfo.rightKeys);
        Map<Integer, Integer> rightToLeftJoinKeysMapping = createKeysMapping(joinInfo.rightKeys, joinInfo.leftKeys);

        List<Integer> leftIndices = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();
        int leftFieldCnt = join.getLeft().getRowType().getFieldCount();
        for (int index : outputIndices) {
            if (index < leftFieldCnt) {
                leftIndices.add(index);
                // A left equi-key can also be addressed through its right-side partner.
                if (leftToRightJoinKeysMapping.containsKey(index)) {
                    rightIndices.add(leftToRightJoinKeysMapping.get(index));
                }
            } else {
                int rightIndex = index - leftFieldCnt;
                rightIndices.add(rightIndex);
                // A right equi-key can also be addressed through its left-side partner.
                if (rightToLeftJoinKeysMapping.containsKey(rightIndex)) {
                    leftIndices.add(rightToLeftJoinKeysMapping.get(rightIndex));
                }
            }
        }

        // A side is only usable when all requested indices were mapped to it.
        ImmutableIntList left =
                leftIndices.size() == outputIndices.size() ? ImmutableIntList.copyOf(leftIndices) : ImmutableIntList.of();
        ImmutableIntList right =
                rightIndices.size() == outputIndices.size() ? ImmutableIntList.copyOf(rightIndices) : ImmutableIntList.of();
        // No (Object) casts here: they would make Tuple2.of infer Tuple2<Object, Object>,
        // which does not match the declared return type.
        return Tuple2.of(left, right);
    }

    /**
     * Builds a map from each key in {@code keyList1} to the key at the same position in
     * {@code keyList2}.
     *
     * <p>Both lists must have the same size. If a key occurs more than once in
     * {@code keyList1}, the mapping of its last occurrence is kept.
     *
     * @param keyList1 keys used as map keys
     * @param keyList2 keys used as map values
     * @return the position-wise mapping
     */
    private static Map<Integer, Integer> createKeysMapping(ImmutableIntList keyList1, ImmutableIntList keyList2) {
        int keyCount = keyList1.size();
        Preconditions.checkState(keyCount == keyList2.size());
        Map<Integer, Integer> mapping = new HashMap<>();
        int position = 0;
        for (Integer key : keyList1) {
            mapping.put(key, keyList2.get(position));
            position++;
        }
        return mapping;
    }

    /**
     * Tells whether {@code rel} is large enough to be a probe side.
     *
     * @param rel candidate probe node
     * @return {@code true} when a data size estimate is available and it is at least the
     *     configured minimum probe data size
     */
    private static boolean canBeProbeSide(RelNode rel) {
        return getEstimatedDataSize(rel)
                .map(estimatedBytes -> estimatedBytes >= (double) getMinProbeDataSize(rel))
                .orElse(false);
    }

    /**
     * Checks whether a build/probe pair satisfies the configured size and filter-ratio limits.
     *
     * @param buildSide candidate build node
     * @param probeSide candidate probe node
     * @param buildIndices key field indices on {@code buildSide}
     * @param probeIndices key field indices on {@code probeSide}
     * @return {@code true} when both size estimates exist, the build size does not exceed the
     *     maximum build size, the probe size is not below the minimum probe size, and the
     *     computed filter ratio is at least the minimum filter ratio
     */
    private static boolean isSuitableDataSize(RelNode buildSide, RelNode probeSide, ImmutableIntList buildIndices, ImmutableIntList probeIndices) {
        Optional<Double> buildBytes = getEstimatedDataSize(buildSide);
        Optional<Double> probeBytes = getEstimatedDataSize(probeSide);
        long maxBuildBytes = getMaxBuildDataSize(buildSide);
        long minProbeBytes = getMinProbeDataSize(probeSide);
        double minFilterRatio = getMinFilterRatio(buildSide);

        boolean sizesKnown = buildBytes.isPresent() && probeBytes.isPresent();
        if (!sizesKnown) {
            return false;
        }

        boolean buildTooLarge = buildBytes.get() > (double) maxBuildBytes;
        boolean probeTooSmall = probeBytes.get() < (double) minProbeBytes;
        if (buildTooLarge || probeTooSmall) {
            return false;
        }

        double filterRatio = computeFilterRatio(buildSide, probeSide, buildIndices, probeIndices);
        return filterRatio >= minFilterRatio;
    }

    /**
     * Computes the filter ratio for a build/probe pair as
     * {@code max(0, 1 - build / probe)}.
     *
     * <p>The distinct-value estimates of the key columns are used when both are available;
     * otherwise the row count estimates are used, and both of those must be available.
     *
     * @param buildSide build node
     * @param probeSide probe node
     * @param buildIndices key field indices on {@code buildSide}
     * @param probeIndices key field indices on {@code probeSide}
     * @return the filter ratio, never below {@code 0.0}
     */
    private static double computeFilterRatio(RelNode buildSide, RelNode probeSide, ImmutableIntList buildIndices, ImmutableIntList probeIndices) {
        Optional<Double> buildNdv = getEstimatedNdv(buildSide, ImmutableBitSet.of(buildIndices));
        Optional<Double> probeNdv = getEstimatedNdv(probeSide, ImmutableBitSet.of(probeIndices));

        double buildCardinality;
        double probeCardinality;
        if (buildNdv.isPresent() && probeNdv.isPresent()) {
            buildCardinality = buildNdv.get();
            probeCardinality = probeNdv.get();
        } else {
            // No usable distinct-value estimates: fall back to plain row counts.
            Optional<Double> buildRows = getEstimatedRowCount(buildSide);
            Optional<Double> probeRows = getEstimatedRowCount(probeSide);
            Preconditions.checkState(buildRows.isPresent() && probeRows.isPresent());
            buildCardinality = buildRows.get();
            probeCardinality = probeRows.get();
        }
        return Math.max(0.0, 1.0 - buildCardinality / probeCardinality);
    }

    /**
     * Skips a top-level exchange.
     *
     * @param relNode node to unwrap
     * @return the first input of {@code relNode} when it is an {@link Exchange}, otherwise
     *     {@code relNode} itself
     */
    private static RelNode ignoreExchange(RelNode relNode) {
        return relNode instanceof Exchange ? relNode.getInput(0) : relNode;
    }

    /**
     * Returns the data size estimate that {@code JoinUtil.binaryRowRelNodeSize} reports for
     * {@code relNode}, or empty when that call yields {@code null}.
     */
    private static Optional<Double> getEstimatedDataSize(RelNode relNode) {
        Double estimatedSize = JoinUtil.binaryRowRelNodeSize(relNode);
        return Optional.ofNullable(estimatedSize);
    }

    /**
     * Returns the row count the metadata query reports for {@code relNode}, or empty when the
     * query yields {@code null}.
     */
    private static Optional<Double> getEstimatedRowCount(RelNode relNode) {
        Double rowCount = relNode.getCluster().getMetadataQuery().getRowCount(relNode);
        return Optional.ofNullable(rowCount);
    }

    /**
     * Returns the distinct row count the metadata query reports for the given key columns of
     * {@code relNode} (queried without a predicate), or empty when the query yields {@code null}.
     */
    private static Optional<Double> getEstimatedNdv(RelNode relNode, ImmutableBitSet keys) {
        Double distinctRows = relNode.getCluster().getMetadataQuery().getDistinctRowCount(relNode, keys, null);
        return Optional.ofNullable(distinctRows);
    }

    /**
     * Reads {@code TABLE_OPTIMIZER_RUNTIME_FILTER_ENABLED} from the table config reachable from
     * {@code relNode}.
     */
    private static boolean isRuntimeFilterEnabled(RelNode relNode) {
        Boolean enabled =
                (Boolean) ShortcutUtils.unwrapTableConfig(relNode).get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_ENABLED);
        return enabled;
    }

    /**
     * Reads {@code TABLE_OPTIMIZER_RUNTIME_FILTER_MAX_BUILD_DATA_SIZE} from the table config
     * reachable from {@code relNode} and returns it in bytes.
     */
    private static long getMaxBuildDataSize(RelNode relNode) {
        MemorySize maxBuildSize =
                (MemorySize) ShortcutUtils.unwrapTableConfig(relNode).get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MAX_BUILD_DATA_SIZE);
        return maxBuildSize.getBytes();
    }

    /**
     * Reads {@code TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_PROBE_DATA_SIZE} from the table config
     * reachable from {@code relNode} and returns it in bytes.
     */
    private static long getMinProbeDataSize(RelNode relNode) {
        MemorySize minProbeSize =
                (MemorySize) ShortcutUtils.unwrapTableConfig(relNode).get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_PROBE_DATA_SIZE);
        return minProbeSize.getBytes();
    }

    /**
     * Reads {@code TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_FILTER_RATIO} from the table config
     * reachable from {@code relNode}.
     */
    private static double getMinFilterRatio(RelNode relNode) {
        Double minRatio =
                (Double) ShortcutUtils.unwrapTableConfig(relNode).get(OptimizerConfigOptions.TABLE_OPTIMIZER_RUNTIME_FILTER_MIN_FILTER_RATIO);
        return minRatio;
    }

    /**
     * Tells whether runtime filters are considered for joins of the given type.
     *
     * @param joinType join type to check; {@code null} yields {@code false}
     * @return {@code true} for INNER, SEMI, LEFT and RIGHT joins, {@code false} otherwise
     */
    public static boolean isSuitableJoinType(JoinRelType joinType) {
        if (joinType == null) {
            return false;
        }
        switch (joinType) {
            case INNER:
            case SEMI:
            case LEFT:
            case RIGHT:
                return true;
            default:
                return false;
        }
    }

    /**
     * Immutable holder for a node chosen as runtime-filter build side together with the key
     * field indices on that node. Both components are required to be non-null.
     */
    private static class BuildSideInfo {
        /** Node the runtime filter is built from. */
        private final RelNode buildSide;
        /** Key field indices, expressed against {@link #buildSide}. */
        private final ImmutableIntList buildIndices;

        public BuildSideInfo(RelNode buildSide, ImmutableIntList buildIndices) {
            Preconditions.checkNotNull(buildSide);
            Preconditions.checkNotNull(buildIndices);
            this.buildSide = buildSide;
            this.buildIndices = buildIndices;
        }
    }
}

