public void onMatch()

in flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/plan/rules/logical/FlinkAggregateJoinTransposeRule.java [176:440]


	/**
	 * Pushes the matched aggregate past an inner equi-join. Each join input is either
	 * aggregated below the join or, when it is already unique on its below-aggregate key,
	 * merely projected. The sub-totals are then combined above the new join.
	 *
	 * <p>The matched aggregate is first converted to a regular aggregate via
	 * {@code toRegularAggregate}. The projection returned alongside it (if non-null) is
	 * re-applied on top of the result, using the original aggregate's field names.
	 *
	 * <p>The method returns without transforming anything when any of these hold:
	 * <ul>
	 *   <li>an aggregate call does not unwrap to a {@link SqlSplittableAggFunction}, has a
	 *   filter argument, or is distinct;</li>
	 *   <li>the join type is not {@link JoinRelType#INNER};</li>
	 *   <li>{@code allowFunctions} is false and the aggregate has aggregate calls;</li>
	 *   <li>the join condition has a non-equi remainder that is not always true;</li>
	 *   <li>both join inputs are judged unique on their below-aggregate keys.</li>
	 * </ul>
	 *
	 * <p>When every join column is covered by the aggregate's key columns, the rule tries
	 * to replace the top aggregate with a plain projection of singleton expressions.
	 *
	 * @param call rule call whose {@code rel(0)} is the {@link Aggregate} and whose
	 *             {@code rel(1)} is the {@link Join} beneath it
	 */
	public void onMatch(RelOptRuleCall call) {
		final Aggregate origAgg = call.rel(0);
		final Join join = call.rel(1);
		final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
		final RelBuilder relBuilder = call.builder();

		// Converts an aggregate with AUXILIARY_GROUP to a regular aggregate.
		// If the converted aggregate can be pushed down,
		// AggregateReduceGroupingRule will try to reduce the grouping of the new
		// aggregates created by this rule.
		final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
		final Aggregate aggregate = newAggAndProject.left;
		// Projection to re-apply on top of the rewritten plan; may be null (checked at the end).
		final List<RexNode> projectAfterAgg = newAggAndProject.right;

		// If any aggregate functions do not support splitting, bail out
		// If any aggregate call has a filter or distinct, bail out
		for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
			if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class)
					== null) {
				return;
			}
			if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
				return;
			}
		}

		// Only inner joins are handled by this rule.
		if (join.getJoinType() != JoinRelType.INNER) {
			return;
		}

		if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
			return;
		}

		// Do the columns used by the join appear in the output of the aggregate?
		final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
		final RelMetadataQuery mq = call.getMetadataQuery();
		final ImmutableBitSet keyColumns = keyColumns(aggregateColumns,
				mq.getPulledUpPredicates(join).pulledUpPredicates);
		final ImmutableBitSet joinColumns =
				RelOptUtil.InputFinder.bits(join.getCondition());
		final boolean allColumnsInAggregate =
				keyColumns.contains(joinColumns);
		// Columns each input must still expose once it is rewritten below the join:
		// the group keys plus every column referenced by the join condition.
		final ImmutableBitSet belowAggregateColumns =
				aggregateColumns.union(joinColumns);

		// Split join condition. The key and null-filter lists are handed to
		// splitJoinCondition and not read afterwards; only the returned non-equi
		// remainder is inspected here.
		final List<Integer> leftKeys = new ArrayList<>();
		final List<Integer> rightKeys = new ArrayList<>();
		final List<Boolean> filterNulls = new ArrayList<>();
		RexNode nonEquiConj =
				RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(),
						join.getCondition(), leftKeys, rightKeys, filterNulls);
		// If it contains non-equi join conditions, we bail out
		if (!nonEquiConj.isAlwaysTrue()) {
			return;
		}

		// Push each aggregate function down to each side that contains all of its
		// arguments. Note that COUNT(*), because it has no arguments, can go to
		// both sides.
		// map: field index in the original join row type -> field index in the row
		// type of the new join built from the two rewritten inputs.
		final Map<Integer, Integer> map = new HashMap<>();
		final List<Side> sides = new ArrayList<>();
		int uniqueCount = 0;
		// offset: first field of the current input within the original join row type.
		int offset = 0;
		// belowOffset: first field of the current rewritten input within the new join row type.
		int belowOffset = 0;
		for (int s = 0; s < 2; s++) {
			final Side side = new Side();
			final RelNode joinInput = join.getInput(s);
			int fieldCount = joinInput.getRowType().getFieldCount();
			final ImmutableBitSet fieldSet =
					ImmutableBitSet.range(offset, offset + fieldCount);
			final ImmutableBitSet belowAggregateKeyNotShifted =
					belowAggregateColumns.intersect(fieldSet);
			for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
				map.put(c.e, belowOffset + c.i);
			}
			// Translates join-level field indices (as used by the aggregate calls) into
			// this input's own field indices: identity on the left, shifted on the right.
			final Mappings.TargetMapping mapping =
					s == 0
							? Mappings.createIdentity(fieldCount)
							: Mappings.createShiftMapping(fieldCount + offset, 0, offset,
							fieldCount);

			// The below-aggregate key expressed in this input's own field indices.
			final ImmutableBitSet belowAggregateKey =
					belowAggregateKeyNotShifted.shift(-offset);
			final boolean unique;
			if (!allowFunctions) {
				assert aggregate.getAggCallList().isEmpty();
				// If there are no functions, it doesn't matter as much whether we
				// aggregate the inputs before the join, because there will not be
				// any functions experiencing a cartesian product effect.
				//
				// But finding out whether the input is already unique requires a call
				// to areColumnsUnique that currently (until [CALCITE-1048] "Make
				// metadata more robust" is fixed) places a heavy load on
				// the metadata system.
				//
				// So we choose to imagine the input is already unique, which is
				// untrue but harmless.
				//
				Util.discard(Bug.CALCITE_1048_FIXED);
				unique = true;
			} else {
				// An unknown (null) uniqueness answer is treated as "not unique".
				final Boolean unique0 =
						mq.areColumnsUnique(joinInput, belowAggregateKey);
				unique = unique0 != null && unique0;
			}
			if (unique) {
				// No aggregate below the join: project the key columns plus one
				// singleton expression per aggregate call whose arguments live here.
				++uniqueCount;
				side.aggregate = false;
				relBuilder.push(joinInput);
				// Position of each key column within the new project. The project
				// compacts the key columns, so a column's index in the input is not
				// necessarily its position in the project.
				final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
				final List<RexNode> projects = new ArrayList<>();
				for (Integer i : belowAggregateKey) {
					belowAggregateKeyToNewProjectMap.put(i, projects.size());
					projects.add(relBuilder.field(i));
				}
				for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
					final SqlAggFunction aggregation = aggCall.e.getAggregation();
					final SqlSplittableAggFunction splitter =
							Preconditions.checkNotNull(
									aggregation.unwrap(SqlSplittableAggFunction.class));
					if (!aggCall.e.getArgList().isEmpty()
							&& fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
						final RexNode singleton = splitter.singleton(rexBuilder,
								joinInput.getRowType(), aggCall.e.transform(mapping));
						// Cast (if needed) to the type of the original aggregate call.
						final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);

						if (targetSingleton instanceof RexInputRef) {
							final int index = ((RexInputRef) targetSingleton).getIndex();
							if (!belowAggregateKey.get(index)) {
								projects.add(targetSingleton);
								side.split.put(aggCall.i, projects.size() - 1);
							} else {
								// Already projected as a key column; reuse that project position.
								side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
							}
						} else {
							projects.add(targetSingleton);
							side.split.put(aggCall.i, projects.size() - 1);
						}
					}
				}
				relBuilder.project(projects);
				side.newInput = relBuilder.build();
			} else {
				// Aggregate this input on its below-aggregate key; the resulting
				// sub-totals are combined by the aggregate above the new join.
				side.aggregate = true;
				List<AggregateCall> belowAggCalls = new ArrayList<>();
				final SqlSplittableAggFunction.Registry<AggregateCall>
						belowAggCallRegistry = registry(belowAggCalls);
				final int oldGroupKeyCount = aggregate.getGroupCount();
				final int newGroupKeyCount = belowAggregateKey.cardinality();
				for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
					final SqlAggFunction aggregation = aggCall.e.getAggregation();
					final SqlSplittableAggFunction splitter =
							Preconditions.checkNotNull(
									aggregation.unwrap(SqlSplittableAggFunction.class));
					final AggregateCall call1;
					if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
						final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
						call1 = splitCall.adaptTo(
								joinInput, splitCall.getArgList(), splitCall.filterArg,
								oldGroupKeyCount, newGroupKeyCount);
					} else {
						call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
					}
					if (call1 != null) {
						// Output position of the pushed-down call: offset past the group-key
						// columns of the below aggregate.
						side.split.put(aggCall.i,
								belowAggregateKey.cardinality()
										+ belowAggCallRegistry.register(call1));
					}
				}
				side.newInput = relBuilder.push(joinInput)
						.aggregate(relBuilder.groupKey(belowAggregateKey, null),
								belowAggCalls)
						.build();
			}
			offset += fieldCount;
			belowOffset += side.newInput.getRowType().getFieldCount();
			sides.add(side);
		}

		if (uniqueCount == 2) {
			// Both inputs to the join are unique. There is nothing to be gained by
			// this rule. In fact, this aggregate+join may be the result of a previous
			// invocation of this rule; if we continue we might loop forever.
			return;
		}

		// Update condition: rewrite field references from the original join row type
		// to the row type of the new join, using the indices collected in `map`.
		final Mapping mapping = (Mapping) Mappings.target(
				map::get,
				join.getRowType().getFieldCount(),
				belowOffset);
		final RexNode newCondition =
				RexUtil.apply(mapping, join.getCondition());

		// Create new join
		relBuilder.push(sides.get(0).newInput)
				.push(sides.get(1).newInput)
				.join(join.getJoinType(), newCondition);

		// Aggregate above to sum up the sub-totals
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final int groupIndicatorCount =
				aggregate.getGroupCount() + aggregate.getIndicatorCount();
		final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
		// Starts as an identity projection of the new join. Presumably topSplit appends
		// any extra expressions it needs through registry(projects) — confirm in registry().
		final List<RexNode> projects =
				new ArrayList<>(
						rexBuilder.identityProjects(relBuilder.peek().getRowType()));
		for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
			final SqlAggFunction aggregation = aggCall.e.getAggregation();
			final SqlSplittableAggFunction splitter =
					Preconditions.checkNotNull(
							aggregation.unwrap(SqlSplittableAggFunction.class));
			// Sub-total positions in the new join row type. Right-side positions are
			// shifted past the left input's fields; -1 means that side has no sub-total.
			final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
			final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
			newAggCalls.add(
					splitter.topSplit(rexBuilder, registry(projects),
							groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e,
							leftSubTotal == null ? -1 : leftSubTotal,
							rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
		}

		relBuilder.project(projects);

		boolean aggConvertedToProjects = false;
		if (allColumnsInAggregate) {
			// let's see if we can convert aggregate into projects
			// (only attempted when every join column is among the aggregate's key columns)
			List<RexNode> projects2 = new ArrayList<>();
			for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
				projects2.add(relBuilder.field(key));
			}
			// Index of the current call's field in the aggregate's row type; the group
			// keys come first.
			int aggCallIdx = projects2.size();
			for (AggregateCall newAggCall : newAggCalls) {
				final SqlSplittableAggFunction splitter =
						newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
				if (splitter != null) {
					final RelDataType rowType = relBuilder.peek().getRowType();
					final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
					// Cast (if needed) to the output type of the corresponding call of the
					// aggregate being replaced.
					final RelDataType originalAggCallType =
							aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
					final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false);
					projects2.add(targetSingleton);
				}
				aggCallIdx += 1;
			}
			// Every call must have produced a singleton; otherwise keep the aggregate.
			if (projects2.size()
					== aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
				// We successfully converted agg calls into projects.
				relBuilder.project(projects2);
				aggConvertedToProjects = true;
			}
		}

		if (!aggConvertedToProjects) {
			relBuilder.aggregate(
					relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()),
							Mappings.apply2(mapping, aggregate.getGroupSets())),
					newAggCalls);
		}
		// Restore the projection produced by toRegularAggregate, keeping the original
		// aggregate's field names.
		if (projectAfterAgg != null) {
			relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
		}

		call.transformTo(relBuilder.build());
	}