def generateTableAggregations()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala [843:1053]


  /**
    * Generates a [[GeneratedTableAggregations]] subclass for table aggregate functions.
    *
    * Besides the shared aggregation methods (accumulate, retract, createAccumulators, ...), the
    * generated class contains an `emit` method and an inner `ConvertCollector` class. The
    * collector converts every record emitted by the user function into a [[Row]] and forwards it
    * to the [[CRowWrappingCollector]] handed to `emit`.
    *
    * @param tableAggOutputRowType    row type whose fields the emitted records are converted to
    * @param tableAggOutputType       type of the records emitted by the table aggregate function;
    *                                 its arity sizes the reused output [[Row]]
    * @param supportEmitIncrementally whether the operator may use `emitUpdateWithRetract`; when
    *                                 false, `emitValue` is always used
    * @return the generated function, constructed with `true` as its last argument iff
    *         `emitUpdateWithRetract` was selected as the emit method
    * @throws CodeGenException if `emitValue` has to be used but no matching method exists
    */
  def generateTableAggregations(
    tableAggOutputRowType: RowTypeInfo,
    tableAggOutputType: TypeInformation[_],
    supportEmitIncrementally: Boolean): GeneratedAggregationsFunction = {

    // constants
    val CONVERT_COLLECTOR_CLASS_TERM = "ConvertCollector"
    val CONVERT_COLLECTOR_VARIABLE_TERM = "convertCollector"
    val COLLECTOR_VARIABLE_TERM = "cRowWrappingcollector"
    val CONVERTER_ROW_RESULT_TERM = "rowTerm"

    // emit methods
    val emitValue = "emitValue"
    val emitUpdateWithRetract = "emitUpdateWithRetract"

    // collectors
    val COLLECTOR: String = classOf[Collector[_]].getCanonicalName
    val CROW_WRAPPING_COLLECTOR: String = classOf[CRowWrappingCollector].getCanonicalName
    val RETRACTABLE_COLLECTOR: String =
      classOf[TableAggregateFunction.RetractableCollector[_]].getCanonicalName

    val ROW: String = classOf[Row].getCanonicalName

    // emitValue is the default emit method. It is switched to emitUpdateWithRetract only if
    // 1. the table aggregate function defines emitUpdateWithRetract, and
    // 2. the operator supports emitting incrementally (e.g. window flatAggregate currently
    //    does not).
    var finalEmitMethodName: String = emitValue

    /** Generates `emit(accs, collector)`, calling the selected emit method per aggregate. */
    def genEmit: String = {

      val sig: String =
        j"""
           |  public final void emit(
           |    $ROW accs,
           |    $COLLECTOR<$ROW> collector) throws Exception """.stripMargin

      val emit: String = {
        for (i <- aggs.indices) yield {
          // The emit method always takes (accumulator, collector), so the separating comma is
          // unconditional, regardless of whether the function has input parameters
          // (parametersCode(i)). Conditioning it produced uncompilable code for no-arg functions.
          val emitAcc =
            j"""
               |      ${genAccDataViewFieldSetter(s"acc$i", i)}
               |      ${aggs(i)}.$finalEmitMethodName(acc$i, $CONVERT_COLLECTOR_VARIABLE_TERM);
             """.stripMargin
          j"""
             |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
             |    $CONVERT_COLLECTOR_VARIABLE_TERM.$COLLECTOR_VARIABLE_TERM =
             |      ($CROW_WRAPPING_COLLECTOR) collector;
             |    $emitAcc
           """.stripMargin
        }
      }.mkString("\n")

      j"""$sig {
         |$emit
         |}""".stripMargin
    }

    /** Generates the code that converts an emitted record (`in1`) into the reused result row. */
    def genRecordToRow: String = {
      // gen access expr

      val functionGenerator = new FunctionCodeGenerator(
        config,
        false,
        tableAggOutputType,
        None,
        None,
        None)

      functionGenerator.outRecordTerm = s"$CONVERTER_ROW_RESULT_TERM"
      val resultExprs = functionGenerator.generateConverterResultExpression(
        tableAggOutputRowType, tableAggOutputRowType.getFieldNames)

      functionGenerator.reuseInputUnboxingCode() + resultExprs.code
    }

    /**
      * Selects `emitValue` as the emit method and validates that `function` defines it for the
      * accumulator type at `index`.
      *
      * @throws CodeGenException if no matching `emitValue` method is found
      */
    def checkAndGetEmitValueMethod(function: UserDefinedFunction, index: Int): Unit = {
      finalEmitMethodName = emitValue
      getUserDefinedMethod(
        function, emitValue, Array(accTypeClasses(index), classOf[Collector[_]]))
        .getOrElse(throw new CodeGenException(
          s"No matching $emitValue method found for " +
            s"tableAggregate '${function.getClass.getCanonicalName}'."))
    }

    /**
      * Runs `init()` and then selects and validates the emit method. emitUpdateWithRetract is
      * preferred if the operator supports emitting incrementally; otherwise, or if it is not
      * defined, emitValue is required.
      */
    def innerInit(): Unit = {
      init()
      aggregates.zipWithIndex.foreach {
        case (a, i) =>
          val incrementalEmitMethod =
            if (supportEmitIncrementally) {
              try {
                getUserDefinedMethod(
                  a,
                  emitUpdateWithRetract,
                  Array(accTypeClasses(i), classOf[TableAggregateFunction.RetractableCollector[_]]))
              } catch {
                // An exception may be thrown when there is no emitUpdateWithRetract method;
                // treat it the same as "not defined" and fall back to emitValue.
                case _: ValidationException => None
              }
            } else {
              None
            }

          if (incrementalEmitMethod.isDefined) {
            finalEmitMethodName = emitUpdateWithRetract
          } else {
            checkAndGetEmitValueMethod(a, i)
          }
      }
    }

    /**
      * Generates the retract method if it is a [[TableAggregateFunction.RetractableCollector]].
      * With `emitValue`, users can't retract, so only blank code is produced.
      */
    def getRetractMethodForConvertCollector(emitMethodName: String): String = {
      if (emitMethodName == emitValue) {
        // Users can't retract messages with emitValue method.
        j"""
           |
          """.stripMargin
      } else {
        // only generates retract method for RetractableCollector
        j"""
           |      @Override
           |      public void retract(Object record) throws Exception {
           |          $COLLECTOR_VARIABLE_TERM.setChange(false);
           |          $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
           |          $COLLECTOR_VARIABLE_TERM.setChange(true);
           |      }
          """.stripMargin
      }
    }

    // Must run before genEmit and the collector base type are computed: it sets
    // finalEmitMethodName.
    innerInit()
    val aggFuncCode = Seq(
      genAccumulate,
      genRetract,
      genCreateAccumulators,
      genCreateOutputRow,
      genSetForwardedFields,
      genMergeAccumulatorsPair,
      genEmit).mkString("\n")

    val generatedAggregationsClass = classOf[GeneratedTableAggregations].getCanonicalName
    val aggOutputTypeName = tableAggOutputType.getTypeClass.getCanonicalName

    val baseCollectorString =
      if (finalEmitMethodName == emitValue) COLLECTOR else RETRACTABLE_COLLECTOR

    val funcCode =
      j"""
         |public final class $funcName extends $generatedAggregationsClass {
         |
         |  private $CONVERT_COLLECTOR_CLASS_TERM $CONVERT_COLLECTOR_VARIABLE_TERM;
         |  ${reuseMemberCode()}
         |  $genMergeList
         |  public $funcName() throws Exception {
         |    ${reuseInitCode()}
         |    $CONVERT_COLLECTOR_VARIABLE_TERM = new $CONVERT_COLLECTOR_CLASS_TERM();
         |  }
         |  ${reuseConstructorCode(funcName)}
         |
         |  public final void open(
         |    org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception {
         |    ${reuseOpenCode()}
         |  }
         |
         |  $aggFuncCode
         |
         |  public final void cleanup() throws Exception {
         |    ${reuseCleanupCode()}
         |  }
         |
         |  public final void close() throws Exception {
         |    ${reuseCloseCode()}
         |  }
         |
         |  private class $CONVERT_COLLECTOR_CLASS_TERM implements $baseCollectorString {
         |
         |      public $CROW_WRAPPING_COLLECTOR $COLLECTOR_VARIABLE_TERM;
         |      private final $ROW $CONVERTER_ROW_RESULT_TERM =
         |        new $ROW(${tableAggOutputType.getArity});
         |
         |      public $ROW convertToRow(Object record) throws Exception {
         |         $aggOutputTypeName in1 = ($aggOutputTypeName) record;
         |         $genRecordToRow
         |         return $CONVERTER_ROW_RESULT_TERM;
         |      }
         |
         |      @Override
         |      public void collect(Object record) throws Exception {
         |          $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
         |      }
         |
         |      ${getRetractMethodForConvertCollector(finalEmitMethodName)}
         |
         |      @Override
         |      public void close() {
         |       $COLLECTOR_VARIABLE_TERM.close();
         |      }
         |  }
         |}
         """.stripMargin

    new GeneratedTableAggregationsFunction(funcName, funcCode, finalEmitMethodName != emitValue)
  }