/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.table.expressions.PlannerResolvedFieldReference;
import org.apache.flink.table.plan.logical.LogicalWindow;
import org.apache.flink.table.plan.logical.rel.LogicalTableAggregate;
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate;
import org.apache.flink.table.plan.logical.rel.LogicalWindowTableAggregate;
import org.apache.flink.table.plan.logical.rel.TableAggregate;

/**
 * Variant of {@link AggregateExtractProjectRule} that, in addition to
 * {@link LogicalAggregate}, also matches {@link LogicalWindowAggregate} and
 * {@link TableAggregate} nodes.
 *
 * <p>On a match the rule pushes a projection on top of the aggregate's input that keeps
 * only the input fields the aggregate references (group keys, aggregate call arguments,
 * filter arguments and, for window aggregates, the window time attribute) and rebuilds
 * the aggregate on top of that projection with remapped field ordinals.
 */
public class ExtendedAggregateExtractProjectRule
extends AggregateExtractProjectRule {
    /** Shared instance matching any {@link SingleRel} over any {@link RelNode}. */
    public static final ExtendedAggregateExtractProjectRule INSTANCE =
        new ExtendedAggregateExtractProjectRule(
            operand(SingleRel.class, operand(RelNode.class, any())),
            RelFactories.LOGICAL_BUILDER);

    /**
     * Creates the rule.
     *
     * @param operand        root operand describing the sub-tree this rule fires on
     * @param builderFactory factory for the {@link RelBuilder} used to create new nodes
     */
    public ExtendedAggregateExtractProjectRule(RelOptRuleOperand operand, RelBuilderFactory builderFactory) {
        super(operand, builderFactory);
    }

    /**
     * Restricts the broad {@link SingleRel} operand to the node kinds this rule can rewrite.
     *
     * @return {@code true} if the matched root is a {@link LogicalWindowAggregate},
     *         a {@link LogicalAggregate} or a {@link TableAggregate}
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        SingleRel relNode = call.rel(0);
        return relNode instanceof LogicalWindowAggregate
            || relNode instanceof LogicalAggregate
            || relNode instanceof TableAggregate;
    }

    /**
     * Rewrites the matched (table) aggregate so that it reads from an extracted projection.
     * Nodes that are neither an {@link Aggregate} nor a {@link TableAggregate} are left alone.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        RelNode relNode = call.rel(0);
        RelNode input = call.rel(1);
        // The builder starts with the aggregate's input on its stack; the helpers below
        // stack the projection and then the new aggregate on top of it.
        RelBuilder relBuilder = call.builder().push(input);
        if (relNode instanceof Aggregate) {
            call.transformTo(performExtractForAggregate((Aggregate) relNode, input, relBuilder));
        } else if (relNode instanceof TableAggregate) {
            call.transformTo(performExtractForTableAggregate((TableAggregate) relNode, input, relBuilder));
        }
    }

    /**
     * Extracts a projection below {@code aggregate} and returns the aggregate rebuilt on top of it.
     *
     * @param aggregate  aggregate to rewrite
     * @param input      input of the aggregate as matched by the rule
     * @param relBuilder builder whose stack top is {@code input}
     * @return the rewritten aggregate
     */
    private RelNode performExtractForAggregate(Aggregate aggregate, RelNode input, RelBuilder relBuilder) {
        Mapping mapping = extractProjectsAndMapping(aggregate, input, relBuilder);
        return getNewAggregate(aggregate, relBuilder, mapping);
    }

    /**
     * Same as {@link #performExtractForAggregate} but for table aggregates: the rewrite is done on
     * the corresponding plain aggregate, whose result is then wrapped back into a table aggregate
     * of the same flavour (windowed or not) as the original.
     *
     * @param aggregate  table aggregate to rewrite
     * @param input      input of the table aggregate as matched by the rule
     * @param relBuilder builder whose stack top is {@code input}
     * @return the rewritten table aggregate
     */
    private RelNode performExtractForTableAggregate(TableAggregate aggregate, RelNode input, RelBuilder relBuilder) {
        RelNode newAggregate = performExtractForAggregate(aggregate.getCorrespondingAggregate(), input, relBuilder);
        if (aggregate instanceof LogicalTableAggregate) {
            return LogicalTableAggregate.create((Aggregate) newAggregate);
        }
        return LogicalWindowTableAggregate.create((LogicalWindowAggregate) newAggregate);
    }

    /**
     * Pushes a projection of the used input fields onto {@code relBuilder} and returns the
     * mapping from original input ordinals to ordinals in that projection.
     *
     * @param aggregate  aggregate whose field usage determines the projection
     * @param input      input of the aggregate
     * @param relBuilder builder whose stack top is {@code input}; the projection is pushed onto it
     * @return mapping from old input field index to new (projected) field index
     */
    private Mapping extractProjectsAndMapping(Aggregate aggregate, RelNode input, RelBuilder relBuilder) {
        ImmutableBitSet.Builder inputFieldsUsed = getInputFieldUsed(aggregate, input);
        List<RexInputRef> projects = new ArrayList<>();
        Mapping mapping = Mappings.create(
            MappingType.INVERSE_SURJECTION,
            aggregate.getInput().getRowType().getFieldCount(),
            inputFieldsUsed.cardinality());
        int target = 0;
        for (int source : inputFieldsUsed.build()) {
            projects.add(relBuilder.field(source));
            mapping.set(source, target++);
        }
        if (input instanceof Project) {
            relBuilder.project(projects);
        } else {
            // The third argument forces the projection to be created
            // (presumably even when it would be trivial — confirm against RelBuilder).
            relBuilder.project(projects, Collections.emptyList(), true);
        }
        return mapping;
    }

    /**
     * Collects the ordinals of all input fields referenced by {@code aggregate}: its group set,
     * every aggregate call's arguments and filter argument, and, for a
     * {@link LogicalWindowAggregate}, the window's time attribute.
     *
     * @param aggregate aggregate to inspect
     * @param input     input of the aggregate, used to resolve the window time attribute by name
     * @return a bit-set builder holding the used input field ordinals
     */
    private ImmutableBitSet.Builder getInputFieldUsed(Aggregate aggregate, RelNode input) {
        ImmutableBitSet.Builder inputFieldsUsed = aggregate.getGroupSet().rebuild();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            for (int arg : aggCall.getArgList()) {
                inputFieldsUsed.set(arg);
            }
            // A negative filterArg means the call has no filter.
            if (aggCall.filterArg >= 0) {
                inputFieldsUsed.set(aggCall.filterArg);
            }
        }
        if (aggregate instanceof LogicalWindowAggregate) {
            LogicalWindow window = ((LogicalWindowAggregate) aggregate).getWindow();
            inputFieldsUsed.set(getWindowTimeFieldIndex(window, input));
        }
        return inputFieldsUsed;
    }

    /**
     * Builds the replacement aggregate on top of the projection currently on the builder stack.
     *
     * @param oldAggregate aggregate being replaced
     * @param relBuilder   builder whose stack top is the extracted projection
     * @param mapping      mapping from old input ordinals to projected ordinals
     * @return the new aggregate, or {@code oldAggregate} itself when it is a window aggregate
     *         with neither group keys nor aggregate calls
     */
    private RelNode getNewAggregate(Aggregate oldAggregate, RelBuilder relBuilder, Mapping mapping) {
        ImmutableBitSet newGroupSet = Mappings.apply(mapping, oldAggregate.getGroupSet());
        List<ImmutableBitSet> newGroupSets = oldAggregate.getGroupSets().stream()
            .map(bitSet -> Mappings.apply(mapping, bitSet))
            .collect(Collectors.toList());
        List<RelBuilder.AggCall> newAggCallList = getNewAggCallList(oldAggregate, relBuilder, mapping);
        RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets);

        boolean isWindowAggregate = oldAggregate instanceof LogicalWindowAggregate;
        if (isWindowAggregate && newGroupSet.isEmpty() && newAggCallList.isEmpty()) {
            // Keep the original node: presumably RelBuilder cannot produce an aggregate that
            // has neither keys nor calls — confirm against RelBuilder.aggregate.
            return oldAggregate;
        }

        relBuilder.aggregate(groupKey, newAggCallList);
        RelNode newAggregate = relBuilder.build();
        if (!isWindowAggregate) {
            return newAggregate;
        }
        // Re-attach the window and its named properties to the rebuilt plain aggregate.
        LogicalWindowAggregate oldWindowAggregate = (LogicalWindowAggregate) oldAggregate;
        return LogicalWindowAggregate.create(
            oldWindowAggregate.getWindow(),
            oldWindowAggregate.getNamedProperties(),
            (Aggregate) newAggregate);
    }

    /**
     * Resolves the ordinal of the window's time attribute within {@code input} by field name.
     *
     * @param logicalWindow window whose time attribute is looked up
     * @param input         relational input whose row type is searched
     * @return the zero-based index of the time attribute in the input row type
     * @throws IllegalStateException if the input has no field with the time attribute's name
     */
    private int getWindowTimeFieldIndex(LogicalWindow logicalWindow, RelNode input) {
        PlannerResolvedFieldReference timeAttribute =
            (PlannerResolvedFieldReference) logicalWindow.timeAttribute();
        List<String> fieldNames = input.getRowType().getFieldNames();
        int index = fieldNames.indexOf(timeAttribute.name());
        if (index < 0) {
            // Fail with context instead of letting -1 reach the bit-set builder,
            // where it would surface as an opaque index error.
            throw new IllegalStateException(
                "Window time attribute '" + timeAttribute.name()
                    + "' not found in input fields " + fieldNames);
        }
        return index;
    }

    /**
     * Re-creates every aggregate call of {@code oldAggregate} against the projected input,
     * remapping argument and filter ordinals and carrying over the distinct / approximate
     * flags, collation and name.
     *
     * @param oldAggregate aggregate whose calls are translated
     * @param relBuilder   builder whose stack top is the extracted projection
     * @param mapping      mapping from old input ordinals to projected ordinals
     * @return the translated aggregate calls, in the original order
     */
    private List<RelBuilder.AggCall> getNewAggCallList(Aggregate oldAggregate, RelBuilder relBuilder, Mapping mapping) {
        List<RelBuilder.AggCall> newAggCallList = new ArrayList<>();
        for (AggregateCall aggCall : oldAggregate.getAggCallList()) {
            RexInputRef filterArg = aggCall.filterArg < 0
                ? null
                : relBuilder.field(Mappings.apply(mapping, aggCall.filterArg));
            // NOTE(review): the collation is passed with its original field ordinals while the
            // arguments and filter are remapped, and collation fields are not part of the used
            // field set. This looks correct only when the collation is empty — confirm.
            newAggCallList.add(
                relBuilder
                    .aggregateCall(
                        aggCall.getAggregation(),
                        relBuilder.fields(Mappings.apply2(mapping, aggCall.getArgList())))
                    .distinct(aggCall.isDistinct())
                    .filter(filterArg)
                    .approximate(aggCall.isApproximate())
                    .sort(relBuilder.fields(aggCall.collation))
                    .as(aggCall.name));
        }
        return newAggCallList;
    }
}

