override def onMatch()

in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitAggregateRule.scala [126:334]


  /**
   * Rewrites the matched [[FlinkLogicalAggregate]] into a two-level (partial / final)
   * aggregation:
   *  1. if some aggregate calls need hash fields, a projection appends
   *     `MOD(HASH_CODE(f), buckets)` for each of their argument fields `f` that is not a group key;
   *  2. a PARTIAL aggregate grouped by the union of all per-call group sets; when more than one
   *     distinct group set exists, it is preceded by an Expand node plus one BOOLEAN filter column
   *     per distinct (group set, filter) pair;
   *  3. a FINAL aggregate grouped by the original group keys (remapped onto the partial output);
   *  4. if any call produced several final calls (e.g. AVG -> SUM0 and COUNT), a projection that
   *     merges them back into one column; finally a conversion to the original row type.
   *
   * The rewritten tree replaces the original aggregate via `call.transformTo`.
   *
   * @param call rule call whose `rel(0)` is the aggregate and `rel(1)` is its input
   */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val tableConfig = call.getPlanner.getContext.unwrap(classOf[FlinkContext]).getTableConfig
    val originalAggregate: FlinkLogicalAggregate = call.rel(0)
    val aggCalls = originalAggregate.getAggCallList
    val input: FlinkRelNode = call.rel(1)
    val cluster = originalAggregate.getCluster
    val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder]
    relBuilder.push(input)
    val aggGroupSet = originalAggregate.getGroupSet.toArray

    // STEP 1: add hash fields if necessary.
    // Argument fields of every call that needs hash fields, deduplicated, minus the fields that
    // are already group keys, in ascending order.
    val hashFieldIndexes: Array[Int] = aggCalls.flatMap { aggCall =>
      if (SplitAggregateRule.needAddHashFields(aggCall)) {
        SplitAggregateRule.getArgIndexes(aggCall)
      } else {
        Array.empty[Int]
      }
    }.distinct.diff(aggGroupSet).sorted.toArray

    // Original argument field index -> index of its appended hash column.
    val hashFieldsMap: util.Map[Int, Int] = new util.HashMap()
    val buckets = tableConfig.getConfiguration.getInteger(
      OptimizerConfigOptions.TABLE_OPTIMIZER_DISTINCT_AGG_SPLIT_BUCKET_NUM)

    if (hashFieldIndexes.nonEmpty) {
      // Keep every existing input field and append one hash column per hash field.
      val projects = new util.ArrayList[RexNode](relBuilder.fields)
      val hashFieldsOffset = projects.size()

      hashFieldIndexes.zipWithIndex.foreach { case (hashFieldIdx, index) =>
        val hashField = relBuilder.field(hashFieldIdx)
        // hash(f) % buckets
        val node: RexNode = relBuilder.call(
          SqlStdOperatorTable.MOD,
          relBuilder.call(FlinkSqlOperatorTable.HASH_CODE, hashField),
          relBuilder.literal(buckets))
        projects.add(node)
        hashFieldsMap.put(hashFieldIdx, hashFieldsOffset + index)
      }
      relBuilder.project(projects)
    }

    // STEP 2: construct partial aggregates.
    // Group set of each call: the original group keys, plus (for calls needing hash fields) the
    // call's argument fields, redirected to their hash columns where one was added.
    val groupSetTreeSet = new util.TreeSet[ImmutableBitSet](ImmutableBitSet.ORDERING)
    val aggInfoToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
    aggCalls.foreach { aggCall =>
      val groupSet = if (SplitAggregateRule.needAddHashFields(aggCall)) {
        val newIndexes = SplitAggregateRule.getArgIndexes(aggCall).map { argIndex =>
          hashFieldsMap.getOrElse(argIndex, argIndex).asInstanceOf[Integer]
        }.toSeq
        ImmutableBitSet.of(newIndexes).union(ImmutableBitSet.of(aggGroupSet: _*))
      } else {
        ImmutableBitSet.of(aggGroupSet: _*)
      }
      groupSetTreeSet.add(groupSet)
      aggInfoToGroupSetMap.put(aggCall, groupSet)
    }
    // Distinct group sets (ordered by ImmutableBitSet.ORDERING) and their union.
    val groupSets = ImmutableList.copyOf(asJavaIterable(groupSetTreeSet))
    val fullGroupSet = ImmutableBitSet.union(groupSets)

    // STEP 2.1: expand input fields.
    // One original call may map to several partial calls; each partial call remembers the group
    // set of the call it came from.
    val partialAggCalls = new util.ArrayList[AggregateCall]
    val partialAggCallToGroupSetMap = new util.HashMap[AggregateCall, ImmutableBitSet]()
    aggCalls.foreach { aggCall =>
      val newAggCalls = SplitAggregateRule.getPartialAggFunction(aggCall).map { aggFunc =>
        AggregateCall.create(aggFunc, aggCall.isDistinct, aggCall.isApproximate, aggCall.getArgList,
          aggCall.filterArg, fullGroupSet.cardinality, relBuilder.peek(), null, null)
      }
      partialAggCalls.addAll(newAggCalls)
      newAggCalls.foreach { newAggCall =>
        partialAggCallToGroupSetMap.put(newAggCall, aggInfoToGroupSetMap.get(aggCall))
      }
    }

    // An Expand node is only needed when the calls do not all share one group set.
    val needExpand = groupSets.size() > 1
    val duplicateFieldMap = if (needExpand) {
      val (duplicateFieldMap, _) = ExpandUtil.buildExpandNode(
        cluster, relBuilder, partialAggCalls, fullGroupSet, groupSets)
      duplicateFieldMap
    } else {
      Map.empty[Integer, Integer]
    }

    // STEP 2.2: add filter columns for partial aggregates.
    // (group set, original filter arg) -> index of the BOOLEAN filter column shared by every
    // partial call with that pair.
    val filters = new util.LinkedHashMap[(ImmutableBitSet, Integer), Integer]
    val newPartialAggCalls = new util.ArrayList[AggregateCall]
    if (needExpand) {
      // The last field produced by the Expand node is the integer expand id (0, 1, 2...).
      // Drop it and instead add one BOOLEAN column per (group set, filter) pair:
      // `expand_id = id`, AND-ed with the original filter column when the call had one.
      val nodes = new util.ArrayList[RexNode](relBuilder.fields)
      val expandIdNode = nodes.remove(nodes.size - 1)
      val filterColumnsOffset: Int = nodes.size
      var x: Int = 0
      partialAggCalls.foreach { aggCall =>
        val groupSet = partialAggCallToGroupSetMap.get(aggCall)
        val oldFilterArg = aggCall.filterArg
        // Redirect arguments that the Expand node duplicated to their duplicated copies.
        val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList
        // Build the key once with the map's exact key type (boxing the filter arg to Integer)
        // instead of relying on deprecated argument auto-tupling through an implicit wrapper.
        val filterKey: (ImmutableBitSet, Integer) = (groupSet, oldFilterArg)

        if (!filters.containsKey(filterKey)) {
          val expandId = ExpandUtil.genExpandId(fullGroupSet, groupSet)
          if (oldFilterArg >= 0) {
            nodes.add(relBuilder.alias(
              relBuilder.and(
                relBuilder.equals(expandIdNode, relBuilder.literal(expandId)),
                relBuilder.field(oldFilterArg)),
              "$g_" + expandId))
          } else {
            nodes.add(relBuilder.alias(
              relBuilder.equals(expandIdNode, relBuilder.literal(expandId)), "$g_" + expandId))
          }
          filters.put(filterKey, filterColumnsOffset + x)
          x += 1
        }

        val newFilterArg = filters.get(filterKey)
        val newAggCall = aggCall.adaptTo(
          relBuilder.peek(), newArgList, newFilterArg,
          fullGroupSet.cardinality, fullGroupSet.cardinality)
        newPartialAggCalls.add(newAggCall)
      }
      relBuilder.project(nodes)
    } else {
      newPartialAggCalls.addAll(partialAggCalls)
    }

    // STEP 2.3: construct partial aggregates
    relBuilder.aggregate(
      relBuilder.groupKey(fullGroupSet, ImmutableList.of[ImmutableBitSet](fullGroupSet)),
      newPartialAggCalls)
    relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
      .setPartialFinalType(PartialFinalType.PARTIAL)

    // STEP 3: construct final aggregates.
    // The partial aggregate outputs its fullGroupSet.cardinality group columns first; each final
    // call (non-distinct, no filter) consumes the next partial output column in order.
    val finalAggInputOffset = fullGroupSet.cardinality
    var x: Int = 0
    val finalAggCalls = new util.ArrayList[AggregateCall]
    var needMergeFinalAggOutput: Boolean = false
    aggCalls.foreach { aggCall =>
      val newAggCalls = SplitAggregateRule.getFinalAggFunction(aggCall).map { aggFunction =>
        val newArgList = ImmutableIntList.of(finalAggInputOffset + x)
        x += 1

        AggregateCall.create(
          aggFunction, false, aggCall.isApproximate, newArgList, -1,
          originalAggregate.getGroupCount, relBuilder.peek(), null, null)
      }

      finalAggCalls.addAll(newAggCalls)
      // A call mapped to several final calls must be merged back into one column in STEP 4.
      if (newAggCalls.size > 1) {
        needMergeFinalAggOutput = true
      }
    }
    relBuilder.aggregate(
      relBuilder.groupKey(
        SplitAggregateRule.remap(fullGroupSet, originalAggregate.getGroupSet),
        SplitAggregateRule.remap(fullGroupSet, Seq(originalAggregate.getGroupSet))),
      finalAggCalls)
    val finalAggregate = relBuilder.peek().asInstanceOf[FlinkLogicalAggregate]
    finalAggregate.setPartialFinalType(PartialFinalType.FINAL)

    // STEP 4: convert final aggregation output to the original aggregation output.
    // For example, aggregate function AVG is transformed to SUM0 and COUNT, so the output of
    // the final aggregation is (sum, count). We should convert it to (sum / count)
    // for the final output.
    val aggGroupCount = finalAggregate.getGroupCount
    if (needMergeFinalAggOutput) {
      val nodes = new util.ArrayList[RexNode]
      // Group columns pass through unchanged.
      (0 until aggGroupCount).foreach { index =>
        nodes.add(RexInputRef.of(index, finalAggregate.getRowType))
      }

      // Each preceding AVG occupies two final columns, shifting later columns by one.
      // NOTE(review): this arithmetic assumes AVG is the only call whose getFinalAggFunction
      // yields more than one function — confirm against SplitAggregateRule.getFinalAggFunction.
      var avgAggCount: Int = 0
      aggCalls.zipWithIndex.foreach { case (aggCall, index) =>
        val newNode = if (aggCall.getAggregation.getKind == SqlKind.AVG) {
          val sumInputRef = RexInputRef.of(
            aggGroupCount + index + avgAggCount,
            finalAggregate.getRowType)
          val countInputRef = RexInputRef.of(
            aggGroupCount + index + avgAggCount + 1,
            finalAggregate.getRowType)
          avgAggCount += 1
          // Make a guarantee that the final aggregation returns NULL if underlying count is ZERO.
          // We use SUM0 for underlying sum, which may run into ZERO / ZERO,
          // and division by zero exception occurs.
          // @see Glossary#SQL2011 SQL:2011 Part 2 Section 6.27
          val equals = relBuilder.call(
            FlinkSqlOperatorTable.EQUALS,
            countInputRef,
            relBuilder.getRexBuilder.makeBigintLiteral(JBigDecimal.valueOf(0)))
          // NOTE(review): casting to the bare SqlTypeName may drop precision/scale of
          // aggCall.type (e.g. DECIMAL); makeNullLiteral(aggCall.`type`) may be more exact —
          // confirm before changing, the trailing relBuilder.convert casts to the original type.
          val ifTrue = relBuilder.cast(
            relBuilder.getRexBuilder.constantNull(), aggCall.`type`.getSqlTypeName)
          val ifFalse = relBuilder.call(FlinkSqlOperatorTable.DIVIDE, sumInputRef, countInputRef)
          relBuilder.call(
            FlinkSqlOperatorTable.IF,
            equals,
            ifTrue,
            ifFalse)
        } else {
          RexInputRef.of(aggGroupCount + index + avgAggCount, finalAggregate.getRowType)
        }
        nodes.add(newNode)
      }
      relBuilder.project(nodes)
    }

    // Cast/rename the result to exactly the original aggregate's row type.
    relBuilder.convert(originalAggregate.getRowType, false)

    val newRel = relBuilder.build()
    call.transformTo(newRel)
  }