in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala [126:334]
override def onMatch(call: RelOptRuleCall): Unit = {
val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
val originalAggregate: FlinkLogicalAggregate = call.rel(0)
val aggCalls = originalAggregate.getAggCallList
val input: FlinkRelNode = call.rel(1)
val cluster = originalAggregate.getCluster
val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder]
relBuilder.push(input)
val aggGroupSet = originalAggregate.getGroupSet.toArray
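// Overall shape of the rewrite: a (distinct) aggregate is split into a partial aggregate
// that additionally groups by a bucket key, plus a final aggregate that merges the partial
// results. Roughly, for an illustrative query (table and column names made up):
//   SELECT a, COUNT(DISTINCT c) FROM T GROUP BY a
// becomes
//   SELECT a, SUM0(cnt) FROM (
//     SELECT a, COUNT(DISTINCT c) AS cnt
//     FROM (SELECT *, MOD(HASH_CODE(c), buckets) AS $f FROM T)
//     GROUP BY a, $f
//   ) GROUP BY a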
// STEP 1: add hash fields if necessary
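// For every aggregate argument that needs a bucket key (e.g. a distinct aggregate's
// argument) and is not already a group key, project MOD(HASH_CODE(f), buckets) onto the
// input; hashFieldsMap records the mapping from the original field index to the index of
// its bucket key.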
val hashFieldIndexes: Array[Int] = aggCalls.flatMap { aggCall =>
if (SplitAggregateRule.needAddHashFields(aggCall)) {
SplitAggregateRule.getArgIndexes(aggCall)
} else {
Array.empty[Int]
}
}.distinct.diff(aggGroupSet).sorted.toArray
val hashFieldsMap: util.Map[Int, Int] = new util.HashMap()
val buckets = tableConfig.getConfiguration.getInteger(
OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_BUCKET_NUM)
if (hashFieldIndexes.nonEmpty) {
val projects = new util.ArrayList[RexNode](relBuilder.fields)
val hashFieldsOffset = projects.size()
hashFieldIndexes.zipWithIndex.foreach { case (hashFieldIdx, index) =>
val hashField = relBuilder.field(hashFieldIdx)
// hash(f) % buckets
val node: RexNode = relBuilder.call(
SqlStdOperatorTable.MOD,
relBuilder.call(FlinkSqlOperatorTable.HASH_CODE, hashField),
relBuilder.literal(buckets))
projects.add(node)
hashFieldsMap.put(hashFieldIdx, hashFieldsOffset + index)
}
relBuilder.project(projects)
}
// STEP 2: construct partial aggregates
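// Each aggregate call gets its own group set: calls that received bucket keys group by
// the original group keys plus the bucket keys of their arguments, all others by the
// original group keys only. The number of distinct group sets decides whether the input
// has to be expanded in STEP 2.1.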
val groupSetTreeSet = new util.TreeSet[ImmutableBitSet](ImmutableBitSet.ORDERING)
val aggInfoToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
aggCalls.foreach { aggCall =>
val groupSet = if (SplitAggregateRule.needAddHashFields(aggCall)) {
val newIndexes = SplitAggregateRule.getArgIndexes(aggCall).map { argIndex =>
hashFieldsMap.getOrElse(argIndex, argIndex).asInstanceOf[Integer]
}.toSeq
ImmutableBitSet.of(newIndexes).union(ImmutableBitSet.of(aggGroupSet: _*))
} else {
ImmutableBitSet.of(aggGroupSet: _*)
}
groupSetTreeSet.add(groupSet)
aggInfoToGroupSetMap.put(aggCall, groupSet)
}
val groupSets = ImmutableList.copyOf(asJavaIterable(groupSetTreeSet))
val fullGroupSet = ImmutableBitSet.union(groupSets)
// STEP 2.1: expand input fields
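// With more than one group set, insert an Expand node that emits one copy of every input
// row per group set, tagged with a trailing expand_id ($e) column. Fields shared between
// group sets may be duplicated by the expansion; duplicateFieldMap maps such an original
// field index to the index of its duplicate.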
val partialAggCalls = new util.ArrayList[AggregateCall]
val partialAggCallToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
aggCalls.foreach { aggCall =>
val newAggCalls = SplitAggregateRule.getPartialAggFunction(aggCall).map { aggFunc =>
AggregateCall.create(aggFunc, aggCall.isDistinct, aggCall.isApproximate, aggCall.getArgList,
aggCall.filterArg, fullGroupSet.cardinality, relBuilder.peek(), null, null)
}
partialAggCalls.addAll(newAggCalls)
newAggCalls.foreach { newAggCall =>
partialAggCallToGroupSetMap.put(newAggCall, aggInfoToGroupSetMap.get(aggCall))
}
}
val needExpand = groupSets.size() > 1
val duplicateFieldMap = if (needExpand) {
val (duplicateFieldMap, _) = ExpandUtil.buildExpandNode(
cluster, relBuilder, partialAggCalls, fullGroupSet, groupSets)
duplicateFieldMap
} else {
Map.empty[Integer, Integer]
}
// STEP 2.2: add filter columns for partial aggregates
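// For each distinct (group set, original filter) pair, add one BOOLEAN column
// $g_<expandId> that is true only for rows expanded for that group set (AND-ed with the
// original filter if there is one), and re-point each partial aggregate call at its $g_
// column so it only consumes its own copy of the expanded rows.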
val filters = new util.LinkedHashMap[(ImmutableBitSet, Integer), Integer]
val newPartialAggCalls = new util.ArrayList[AggregateCall]
if (needExpand) {
// The trailing expand_id ($e) column holds an integer (0, 1, 2...) identifying which
// group set a row was expanded for.
// Add a project that turns those values into BOOLEAN filter columns.
val nodes = new util.ArrayList[RexNode](relBuilder.fields)
val expandIdNode = nodes.remove(nodes.size - 1)
val filterColumnsOffset: Int = nodes.size
var x: Int = 0
partialAggCalls.foreach { aggCall =>
val groupSet = partialAggCallToGroupSetMap.get(aggCall)
val oldFilterArg = aggCall.filterArg
val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList
if (!filters.contains(groupSet, oldFilterArg)) {
val expandId = ExpandUtil.genExpandId(fullGroupSet, groupSet)
if (oldFilterArg >= 0) {
nodes.add(relBuilder.alias(
relBuilder.and(
relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
relBuilder.field(oldFilterArg)),
"$g_" + expandId))
} else {
nodes.add(relBuilder.alias(
relBuilder.equals(expandIdNode, relBuilder.literal(expandId)), "$g_" + expandId))
}
val newFilterArg = filterColumnsOffset + x
filters.put((groupSet, oldFilterArg), newFilterArg)
x += 1
}
val newFilterArg = filters((groupSet, oldFilterArg))
val newAggCall = aggCall.adaptTo(
relBuilder.peek(), newArgList, newFilterArg,
fullGroupSet.cardinality, fullGroupSet.cardinality)
newPartialAggCalls.add(newAggCall)
}
relBuilder.project(nodes)
} else {
newPartialAggCalls.addAll(partialAggCalls)
}
// STEP 2.3: construct partial aggregates
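// Group by the union of all group sets and mark the result as PARTIAL, which prevents this
// rule from matching the new aggregate again.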
relBuilder.aggregate(
relBuilder.groupKey(fullGroupSet, ImmutableList.of[ImmutableBitSet](fullGroupSet)),
newPartialAggCalls)
relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
.setPartialFinalType(PartialFinalType.PARTIAL)
// STEP 3: construct final aggregates
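// Each final call consumes exactly one partial result column, e.g. a partial COUNT is
// summed up by a final SUM0. AVG is the special case: it was split into a (SUM0, COUNT)
// pair, so it produces two final calls whose outputs are merged back in STEP 4.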
val finalAggInputOffset = fullGroupSet.cardinality
var x: Int = 0
val finalAggCalls = new util.ArrayList[AggregateCall]
var needMergeFinalAggOutput: Boolean = false
aggCalls.foreach { aggCall =>
val newAggCalls = SplitAggregateRule.getFinalAggFunction(aggCall).map { aggFunction =>
val newArgList = ImmutableIntList.of(finalAggInputOffset + x)
x += 1
AggregateCall.create(
aggFunction, false, aggCall.isApproximate, newArgList, -1,
originalAggregate.getGroupCount, relBuilder.peek(), null, null)
}
finalAggCalls.addAll(newAggCalls)
if (newAggCalls.size > 1) {
needMergeFinalAggOutput = true
}
}
relBuilder.aggregate(
relBuilder.groupKey(
SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet))),
finalAggCalls)
val finalAggregate = relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
finalAggregate.setPartialFinalType(PartialFinalType.FINAL)
// STEP 4: convert final aggregation output to the original aggregation output.
// For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of
// the final aggregation is (sum, count). We should convert it to (sum / count)
// for the final output.
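// For an AVG call the projection built here is, roughly:
//   IF(count = 0, CAST(NULL AS <result type>), sum / count)
// All other aggregate outputs are forwarded unchanged (shifted past the extra count
// columns introduced by preceding AVGs).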
val aggGroupCount = finalAggregate.getGroupCount
if (needMergeFinalAggOutput) {
val nodes = new util.ArrayList[RexNode]
(0 until aggGroupCount).foreach { index =>
nodes.add(RexInputRef.of(index, finalAggregate.getRowType))
}
var avgAggCount: Int = 0
aggCalls.zipWithIndex.foreach { case (aggCall, index) =>
val newNode = if (aggCall.getAggregation.getKind == SqlKind.AVG) {
val sumInputRef = RexInputRef.of(
aggGroupCount + index + avgAggCount,
finalAggregate.getRowType)
val countInputRef = RexInputRef.of(
aggGroupCount + index + avgAggCount + 1,
finalAggregate.getRowType)
avgAggCount += 1
// Guarantee that the final aggregation returns NULL if the underlying count is ZERO.
// We use SUM0 for the underlying sum, so a ZERO count would lead to ZERO / ZERO and
// a division-by-zero exception.
// @see Glossary#SQL2011 SQL:2011 Part 2 Section 6.27
val equals = relBuilder.call(
FlinkSqlOperatorTable.EQUALS,
countInputRef,
relBuilder.getRexBuilder.makeBigintLiteral(JBigDecimal.valueOf(0)))
val ifTrue = relBuilder.cast(
relBuilder.getRexBuilder.constantNull(), aggCall.`type`.getSqlTypeName)
val ifFalse = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
relBuilder.call(
FlinkSqlOperatorTable.IF,
equals,
ifTrue,
ifFalse)
} else {
RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
}
nodes.add(newNode)
}
relBuilder.project(nodes)
}
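// Cast the output back to the original aggregate's row type (without renaming fields) so
// the rewritten plan is a drop-in replacement for the matched aggregate.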
relBuilder.convert(originalAggregate.getRowType, false)
val newRel = relBuilder.build()
call.transformTo(newRel)
}