in flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateExpandDistinctAggregatesRule.java [416:569]
/**
 * Rewrites an aggregate that contains distinct aggregate calls into two stacked aggregates
 * built on grouping sets, and hands the result to {@code call.transformTo}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>A bottom aggregate over the original input uses one grouping set per distinct call
 *       (its arguments, its filter argument if present, and the original group keys), plus the
 *       original group set when non-distinct calls exist. It evaluates the non-distinct calls
 *       and a {@code GROUPING} call named {@code $g}.
 *   <li>A projection replaces {@code $g} with one boolean column per grouping set.
 *   <li>A top aggregate over the original group keys uses those boolean columns as filter
 *       arguments; non-distinct results are picked up with {@code MIN}, distinct calls keep
 *       their own aggregate function.
 *   <li>When the original aggregate has no group keys, COUNT outputs are wrapped so that a
 *       null result becomes zero.
 *   <li>The result is converted to the row type of the original aggregate.
 * </ol>
 *
 * @param call rule call that supplies the {@code RelBuilder} and receives the rewritten plan
 * @param aggregate aggregate to rewrite
 */
private void rewriteUsingGroupingSets(RelOptRuleCall call,
        Aggregate aggregate) {
    final ImmutableBitSet originalGroupSet = aggregate.getGroupSet();

    // Gather the grouping sets in ImmutableBitSet.ORDERING order, remembering the filter
    // argument of the distinct call that produced each set.
    final Set<ImmutableBitSet> orderedGroupSets = new TreeSet<>(ImmutableBitSet.ORDERING);
    final Map<ImmutableBitSet, Integer> filterArgByDistinctGroupSet = new HashMap<>();
    for (AggregateCall original : aggregate.getAggCallList()) {
        if (original.isDistinct()) {
            final ImmutableBitSet distinctKeys =
                    ImmutableBitSet.of(original.getArgList())
                            .setIf(original.filterArg, original.filterArg >= 0)
                            .union(originalGroupSet);
            filterArgByDistinctGroupSet.put(distinctKeys, original.filterArg);
            orderedGroupSets.add(distinctKeys);
        } else {
            orderedGroupSets.add(originalGroupSet);
        }
    }

    final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets =
            com.google.common.collect.ImmutableList.copyOf(orderedGroupSets);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    final int groupCount = fullGroupSet.cardinality();

    // Non-distinct calls are evaluated by the bottom aggregate, adapted to its group count.
    final List<AggregateCall> bottomCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> named : aggregate.getNamedAggCalls()) {
        final AggregateCall original = named.left;
        if (original.isDistinct()) {
            continue;
        }
        final AggregateCall adapted = original.adaptTo(
                aggregate.getInput(),
                original.getArgList(),
                original.filterArg,
                aggregate.getGroupCount(),
                groupCount);
        bottomCalls.add(adapted.rename(named.right));
    }

    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());

    // The GROUPING call is appended last; after the projection below, the column at
    // (firstFilterField + i) carries the boolean for the i-th grouping set.
    final int firstFilterField = groupCount + bottomCalls.size();
    bottomCalls.add(
            AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
                    false, ImmutableIntList.copyOf(fullGroupSet), -1,
                    RelCollations.EMPTY, groupSets.size(),
                    relBuilder.peek(), null, "$g"));
    final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
    for (int i = 0; i < groupSets.size(); i++) {
        filters.put(groupSets.get(i), firstFilterField + i);
    }

    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), bottomCalls);
    final RelNode distinct = relBuilder.peek();

    // Replace the trailing "$g" column with one boolean column per grouping set: the column
    // tests "$g" against that set's group value and, for a filtered distinct call, also
    // requires the call's own filter column to be true.
    if (!filters.isEmpty()) {
        final List<RexNode> projections = new ArrayList<>(relBuilder.fields());
        final RexNode groupingRef = projections.remove(projections.size() - 1);
        for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
            final ImmutableBitSet groupSet = entry.getKey();
            final long groupId = groupValue(fullGroupSet, groupSet);
            // Filter argument of the distinct call, translated to its position in fullGroupSet.
            final int remappedFilterArg = remap(fullGroupSet,
                    filterArgByDistinctGroupSet.getOrDefault(groupSet, -1));
            RexNode condition = relBuilder.equals(groupingRef, relBuilder.literal(groupId));
            if (remappedFilterArg >= 0) {
                final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                condition = relBuilder.and(
                        condition,
                        rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE,
                                relBuilder.field(remappedFilterArg)));
            }
            projections.add(relBuilder.alias(condition, "$g_" + groupId));
        }
        relBuilder.project(projections);
    }

    // TODO supports more aggCalls (currently only supports COUNT)
    // Some aggregate functions (e.g. COUNT) can return a non-null result without any input.
    // The positions of such calls are recorded so a default value can be supplied later.
    final List<Integer> countCallIndexes = new ArrayList<>();
    final List<AggregateCall> topCalls = new ArrayList<>();
    int nextNonDistinctField = groupCount;
    int callIndex = 0;
    for (AggregateCall original : aggregate.getAggCallList()) {
        final SqlAggFunction function;
        final List<Integer> argList;
        final int filterArg;
        if (original.isDistinct()) {
            function = original.getAggregation();
            argList = remap(fullGroupSet, original.getArgList());
            filterArg =
                    filters.get(
                            ImmutableBitSet.of(original.getArgList())
                                    .setIf(original.filterArg, original.filterArg >= 0)
                                    .union(originalGroupSet));
        } else {
            // The bottom aggregate already computed this call; MIN just picks the value up.
            function = SqlStdOperatorTable.MIN;
            argList = ImmutableIntList.of(nextNonDistinctField++);
            filterArg = filters.get(originalGroupSet);
            if (original.getAggregation().getKind() == SqlKind.COUNT) {
                countCallIndexes.add(callIndex);
            }
        }
        topCalls.add(
                AggregateCall.create(function, false, original.isApproximate(),
                        false, argList, filterArg,
                        RelCollations.EMPTY, aggregate.getGroupCount(), distinct,
                        null, original.name));
        callIndex++;
    }

    relBuilder.aggregate(
            relBuilder.groupKey(
                    remap(fullGroupSet, originalGroupSet),
                    remap(fullGroupSet, aggregate.getGroupSets())),
            topCalls);

    // Without group keys, wrap each recorded COUNT output so that null is replaced by zero.
    if (!countCallIndexes.isEmpty() && aggregate.getGroupCount() == 0) {
        final Aggregate topAggregate = (Aggregate) relBuilder.peek();
        final List<RexNode> outputs = new ArrayList<>();
        for (int i = 0; i < topAggregate.getGroupCount(); ++i) {
            outputs.add(RexInputRef.of(i, topAggregate.getRowType()));
        }
        for (int i = 0; i < topAggregate.getAggCallList().size(); ++i) {
            final RexNode ref =
                    RexInputRef.of(topAggregate.getGroupCount() + i, topAggregate.getRowType());
            final boolean needsZeroDefault = countCallIndexes.contains(i)
                    && aggregate.getAggCallList().get(i).getAggregation().getKind()
                            == SqlKind.COUNT;
            if (needsZeroDefault) {
                outputs.add(
                        relBuilder.call(
                                SqlStdOperatorTable.CASE,
                                relBuilder.isNotNull(ref),
                                ref,
                                relBuilder.literal(BigDecimal.ZERO)));
            } else {
                outputs.add(ref);
            }
        }
        relBuilder.project(outputs);
    }

    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}