private void rewriteUsingGroupingSets()

in flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateExpandDistinctAggregatesRule.java [416:569]


	/**
	 * Rewrites an aggregate containing DISTINCT aggregate calls into two stacked aggregates
	 * using GROUPING SETS.
	 *
	 * <p>The bottom aggregate groups the input by every grouping set the calls need. These are
	 * the original group set (for non-distinct calls) and, for each distinct call, its arguments
	 * plus its filter column (if any) plus the original group set. It evaluates every
	 * non-distinct call plus a {@code GROUPING} call over the full group set. A projection then
	 * replaces the {@code GROUPING} column with one BOOLEAN marker column per
	 * (grouping set, filter argument) pair.
	 *
	 * <p>The top aggregate groups by the remapped original grouping sets:
	 * <ul>
	 *   <li>Each distinct call is evaluated without DISTINCT over the rows selected by its
	 *   marker.</li>
	 *   <li>Each non-distinct call becomes {@code MIN} over its pre-computed bottom result,
	 *   selected by the marker of the original group set.</li>
	 *   <li>When there are no GROUP BY keys, non-distinct COUNT results are wrapped so that a
	 *   NULL result becomes 0.</li>
	 * </ul>
	 *
	 * <p>Calls that share a grouping set may still differ in filter argument. For example,
	 * {@code COUNT(DISTINCT a) FILTER (WHERE b)} and {@code COUNT(DISTINCT a, b)} both use
	 * grouping set {a, b}. Each (grouping set, filter argument) pair therefore gets its own
	 * marker. Sharing one marker per grouping set would apply one call's filter to the other
	 * call, or drop it.
	 *
	 * @param call the rule call; the rewritten plan is handed to {@code transformTo}
	 * @param aggregate the aggregate to rewrite
	 */
	private void rewriteUsingGroupingSets(RelOptRuleCall call,
			Aggregate aggregate) {
		final ImmutableBitSet aggGroupSet = aggregate.getGroupSet();

		// (grouping set, filter argument) of every aggregate call, indexed by call position;
		// filter argument -1 means "no marker-level filter". Non-distinct calls always use the
		// original group set with -1, because their own filter is applied inside the bottom
		// aggregate (see adaptTo below).
		final List<Pair<ImmutableBitSet, Integer>> callMarkerKeys = new ArrayList<>();
		// All grouping sets needed by the bottom aggregate, ordered deterministically.
		final Set<ImmutableBitSet> groupSetTreeSet =
				new TreeSet<>(ImmutableBitSet.ORDERING);
		// Grouping set -> filter arguments used with it. A set is required because calls that
		// share a grouping set may carry different filter arguments. TreeSet keeps the order of
		// the marker columns stable.
		final Map<ImmutableBitSet, Set<Integer>> groupSetToFilterArgs = new HashMap<>();
		for (AggregateCall aggCall : aggregate.getAggCallList()) {
			final ImmutableBitSet callGroupSet;
			final int callFilterArg;
			if (!aggCall.isDistinct()) {
				callGroupSet = aggGroupSet;
				callFilterArg = -1;
			} else {
				callGroupSet = ImmutableBitSet.of(aggCall.getArgList())
						.setIf(aggCall.filterArg, aggCall.filterArg >= 0)
						.union(aggGroupSet);
				callFilterArg = aggCall.filterArg;
			}
			callMarkerKeys.add(Pair.of(callGroupSet, callFilterArg));
			groupSetTreeSet.add(callGroupSet);
			groupSetToFilterArgs.computeIfAbsent(callGroupSet, k -> new TreeSet<>()).add(callFilterArg);
		}

		final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets =
				com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
		final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
		final int groupCount = fullGroupSet.cardinality();

		// Calls of the bottom aggregate. Each non-distinct call is adapted to the full group set.
		// A GROUPING call that identifies each row's grouping set is appended after them below.
		final List<AggregateCall> bottomAggCalls = new ArrayList<>();
		for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
			if (!aggCall.left.isDistinct()) {
				final AggregateCall newAggCall = aggCall.left.adaptTo(
						aggregate.getInput(),
						aggCall.left.getArgList(),
						aggCall.left.filterArg,
						aggregate.getGroupCount(),
						groupCount);
				bottomAggCalls.add(newAggCall.rename(aggCall.right));
			}
		}

		final RelBuilder relBuilder = call.builder();
		relBuilder.push(aggregate.getInput());

		// Ordinal of each marker column in the projection built below. The markers replace the
		// trailing GROUPING column. They therefore start right after the group keys and the
		// non-distinct results, in grouping-set order.
		final Map<Pair<ImmutableBitSet, Integer>, Integer> filters = new LinkedHashMap<>();
		int markerOrdinal = groupCount + bottomAggCalls.size();
		for (ImmutableBitSet markerGroupSet : groupSets) {
			for (Integer filterArg : groupSetToFilterArgs.get(markerGroupSet)) {
				filters.put(Pair.of(markerGroupSet, filterArg), markerOrdinal++);
			}
		}

		bottomAggCalls.add(
				AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false,
						false, ImmutableIntList.copyOf(fullGroupSet), -1,
						RelCollations.EMPTY, groupSets.size(),
						relBuilder.peek(), null, "$g"));

		relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets),
				bottomAggCalls);
		final RelNode distinct = relBuilder.peek();

		// GROUPING returns an integer identifying the grouping set of each row. Replace it with
		// one BOOLEAN marker per (grouping set, filter argument) pair.
		if (!filters.isEmpty()) {
			final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
			final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
			final RexNode nodeZ = nodes.remove(nodes.size() - 1);
			// LinkedHashMap iteration order matches the ordinals assigned above.
			for (Pair<ImmutableBitSet, Integer> key : filters.keySet()) {
				final long v = groupValue(fullGroupSet, key.left);
				// Translate (via remap) the input ordinal of the filter column into a field
				// of the bottom aggregate, whose leading columns are the full group set.
				final int filterArg = remap(fullGroupSet, key.right);
				RexNode expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
				String alias = "$g_" + v;
				if (filterArg >= 0) {
					// Merge the filter of the distinct aggregate call itself.
					expr = relBuilder.and(expr,
							rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE,
									relBuilder.field(filterArg)));
					// Keep marker names unique when one grouping set has several filters.
					alias += "_f_" + filterArg;
				}
				nodes.add(relBuilder.alias(expr, alias));
			}
			relBuilder.project(nodes);
		}

		int x = groupCount;
		final List<AggregateCall> newCalls = new ArrayList<>();
		// TODO supports more aggCalls (currently only supports COUNT)
		// Some aggregate functions (e.g. COUNT) have the special property that they can return a
		// non-null result without any input. We need to make sure we return a result in this case.
		final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
		for (Ord<AggregateCall> ord : Ord.zip(aggregate.getAggCallList())) {
			final AggregateCall aggCall = ord.e;
			// Each call is restricted to the rows of its own (grouping set, filter) marker.
			final int newFilterArg = filters.get(callMarkerKeys.get(ord.i));
			final List<Integer> newArgList;
			final SqlAggFunction aggregation;
			if (!aggCall.isDistinct()) {
				// The bottom aggregate already computed this call; take MIN of that result over
				// the rows of the original group set.
				aggregation = SqlStdOperatorTable.MIN;
				newArgList = ImmutableIntList.of(x++);
				switch (aggCall.getAggregation().getKind()) {
					case COUNT:
						needDefaultValueAggCalls.add(ord.i);
						break;
					default:
				}
			} else {
				aggregation = aggCall.getAggregation();
				newArgList = remap(fullGroupSet, aggCall.getArgList());
			}
			final AggregateCall newCall =
					AggregateCall.create(aggregation, false, aggCall.isApproximate(),
							false, newArgList, newFilterArg,
							RelCollations.EMPTY, aggregate.getGroupCount(), distinct,
							null, aggCall.name);
			newCalls.add(newCall);
		}

		relBuilder.aggregate(
				relBuilder.groupKey(
						remap(fullGroupSet, aggGroupSet),
						remap(fullGroupSet, aggregate.getGroupSets())),
				newCalls);
		if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
			final Aggregate newAgg = (Aggregate) relBuilder.peek();
			final List<RexNode> nodes = new ArrayList<>();
			for (int i = 0; i < newAgg.getGroupCount(); ++i) {
				nodes.add(RexInputRef.of(i, newAgg.getRowType()));
			}
			for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
				final RexNode inputRef = RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
				RexNode newNode = inputRef;
				if (needDefaultValueAggCalls.contains(i)) {
					SqlKind originalFunKind = aggregate.getAggCallList().get(i).getAggregation().getKind();
					switch (originalFunKind) {
						case COUNT:
							// CASE WHEN ref IS NOT NULL THEN ref ELSE 0
							newNode = relBuilder.call(
									SqlStdOperatorTable.CASE,
									relBuilder.isNotNull(inputRef),
									inputRef,
									relBuilder.literal(BigDecimal.ZERO));
							break;
						default:
					}
				}
				nodes.add(newNode);
			}
			relBuilder.project(nodes);
		}

		relBuilder.convert(aggregate.getRowType(), true);
		call.transformTo(relBuilder.build());
	}