in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala [843:1053]
/**
  * Generates the code of a [[GeneratedTableAggregations]] subclass that drives the table
  * aggregate functions of this generator.
  *
  * Besides the accumulate/retract/create/merge methods generated by the shared helpers, the
  * generated class contains:
  *  - an `emit(accs, collector)` method that invokes the selected emit method
  *    (`emitValue` or `emitUpdateWithRetract`) on every aggregate, and
  *  - an inner `ConvertCollector` that converts each record emitted by a user function into a
  *    reusable [[Row]] and forwards it to the [[CRowWrappingCollector]] passed to `emit`.
  *
  * @param tableAggOutputRowType    row type used to convert emitted records into [[Row]]s
  * @param tableAggOutputType       type of the records emitted by the table aggregate functions
  * @param supportEmitIncrementally whether the calling operator can handle incremental
  *                                 emission via `emitUpdateWithRetract` (e.g. window
  *                                 flatAggregate currently cannot)
  * @return the generated function, flagged with whether `emitUpdateWithRetract` is used
  */
def generateTableAggregations(
    tableAggOutputRowType: RowTypeInfo,
    tableAggOutputType: TypeInformation[_],
    supportEmitIncrementally: Boolean): GeneratedAggregationsFunction = {

  // constants
  val CONVERT_COLLECTOR_CLASS_TERM = "ConvertCollector"
  val CONVERT_COLLECTOR_VARIABLE_TERM = "convertCollector"
  val COLLECTOR_VARIABLE_TERM = "cRowWrappingcollector"
  val CONVERTER_ROW_RESULT_TERM = "rowTerm"

  // emit methods
  val emitValue = "emitValue"
  val emitUpdateWithRetract = "emitUpdateWithRetract"

  // collectors
  val COLLECTOR: String = classOf[Collector[_]].getCanonicalName
  val CROW_WRAPPING_COLLECTOR: String = classOf[CRowWrappingCollector].getCanonicalName
  val RETRACTABLE_COLLECTOR: String =
    classOf[TableAggregateFunction.RetractableCollector[_]].getCanonicalName
  val ROW: String = classOf[Row].getCanonicalName

  // Emit method invoked on EVERY table aggregate function by the generated emit(). It defaults
  // to emitValue; innerInit() switches it to emitUpdateWithRetract only if (1) the operator
  // supports emitting incrementally (for example, window flatAggregate doesn't support it now)
  // and (2) every table aggregate function defines a matching emitUpdateWithRetract method.
  var finalEmitMethodName: String = emitValue

  /** Generates `emit(accs, collector)`, which calls the selected emit method on each agg. */
  def genEmit: String = {
    val sig: String =
      j"""
         |  public final void emit(
         |    $ROW accs,
         |    $COLLECTOR<$ROW> collector) throws Exception """.stripMargin

    val emit: String = {
      for (i <- aggs.indices) yield {
        // Both emit methods take exactly (accumulator, collector) -- see the signatures checked
        // in innerInit() -- so the comma is unconditional. It must not depend on whether the
        // aggregate has input parameters, otherwise zero-argument aggregates generate
        // uncompilable code.
        val emitAcc =
          j"""
             |      ${genAccDataViewFieldSetter(s"acc$i", i)}
             |      ${aggs(i)}.$finalEmitMethodName(acc$i, $CONVERT_COLLECTOR_VARIABLE_TERM);
           """.stripMargin
        // Point the convert collector at the downstream collector before emitting.
        j"""
           |    ${accTypes(i)} acc$i = (${accTypes(i)}) accs.getField($i);
           |    $CONVERT_COLLECTOR_VARIABLE_TERM.$COLLECTOR_VARIABLE_TERM =
           |      ($CROW_WRAPPING_COLLECTOR) collector;
           |    $emitAcc
         """.stripMargin
      }
    }.mkString("\n")

    j"""$sig {
       |$emit
       |}""".stripMargin
  }

  /**
    * Generates the statements that fill the reusable result row (`rowTerm`) from the emitted
    * record. The record is bound to `in1` in the generated `convertToRow`; this assumes that
    * is the input term FunctionCodeGenerator reads by default -- TODO confirm.
    */
  def genRecordToRow: String = {
    // gen access expr
    val functionGenerator = new FunctionCodeGenerator(
      config,
      false,
      tableAggOutputType,
      None,
      None,
      None)
    functionGenerator.outRecordTerm = CONVERTER_ROW_RESULT_TERM
    val resultExprs = functionGenerator.generateConverterResultExpression(
      tableAggOutputRowType, tableAggOutputRowType.getFieldNames)
    functionGenerator.reuseInputUnboxingCode() + resultExprs.code
  }

  /**
    * Selects emitValue and checks that the aggregate at `index` defines
    * `emitValue(acc, Collector)`.
    *
    * @throws CodeGenException if no emitValue method with a matching signature is found
    */
  def checkAndGetEmitValueMethod(function: UserDefinedFunction, index: Int): Unit = {
    finalEmitMethodName = emitValue
    getUserDefinedMethod(
      function, emitValue, Array(accTypeClasses(index), classOf[Collector[_]]))
      .getOrElse(throw new CodeGenException(
        s"No matching $emitValue method found for " +
          s"tableAggregate '${function.getClass.getCanonicalName}'."))
  }

  /**
    * Returns whether the aggregate at `index` defines
    * `emitUpdateWithRetract(acc, RetractableCollector)`.
    */
  def hasEmitUpdateWithRetractMethod(function: UserDefinedFunction, index: Int): Boolean = {
    try {
      getUserDefinedMethod(
        function,
        emitUpdateWithRetract,
        Array(accTypeClasses(index), classOf[TableAggregateFunction.RetractableCollector[_]]))
        .isDefined
    } catch {
      // The lookup throws (rather than returning None) when the function has no
      // emitUpdateWithRetract method at all.
      case _: ValidationException => false
    }
  }

  /**
    * Calls the shared init() and then selects and validates the emit method.
    *
    * A single emit method name is used for all aggregates in genEmit, so it must be valid for
    * all of them. emitUpdateWithRetract is chosen only when the operator supports emitting
    * incrementally and every aggregate defines it. Otherwise every aggregate must define
    * emitValue.
    */
  def innerInit(): Unit = {
    init()
    val emitIncrementally = supportEmitIncrementally &&
      aggregates.nonEmpty &&
      aggregates.zipWithIndex.forall { case (a, i) => hasEmitUpdateWithRetractMethod(a, i) }
    if (emitIncrementally) {
      finalEmitMethodName = emitUpdateWithRetract
    } else {
      aggregates.zipWithIndex.foreach { case (a, i) => checkAndGetEmitValueMethod(a, i) }
    }
  }

  /**
    * Generates the retract method of the convert collector. It is generated only when the
    * collector is a [[TableAggregateFunction.RetractableCollector]], i.e. when
    * emitUpdateWithRetract is used. With emitValue, users cannot retract, so nothing is
    * generated.
    */
  def getRetractMethodForConvertCollector(emitMethodName: String): String = {
    if (emitMethodName == emitValue) {
      // Users can't retract messages with emitValue method.
      j"""
         |
       """.stripMargin
    } else {
      // Flag the forwarded row as a retraction, then restore the default (accumulate) flag.
      j"""
         |    @Override
         |    public void retract(Object record) throws Exception {
         |      $COLLECTOR_VARIABLE_TERM.setChange(false);
         |      $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
         |      $COLLECTOR_VARIABLE_TERM.setChange(true);
         |    }
       """.stripMargin
    }
  }

  // Must run before any gen* call: it initializes the shared state and selects the emit method.
  innerInit()

  val aggFuncCode = Seq(
    genAccumulate,
    genRetract,
    genCreateAccumulators,
    genCreateOutputRow,
    genSetForwardedFields,
    genMergeAccumulatorsPair,
    genEmit).mkString("\n")

  val generatedAggregationsClass = classOf[GeneratedTableAggregations].getCanonicalName
  val aggOutputTypeName = tableAggOutputType.getTypeClass.getCanonicalName
  // The convert collector must be a RetractableCollector when emitUpdateWithRetract is used.
  val baseCollectorString =
    if (finalEmitMethodName == emitValue) COLLECTOR else RETRACTABLE_COLLECTOR

  val funcCode =
    j"""
       |public final class $funcName extends $generatedAggregationsClass {
       |
       |  private $CONVERT_COLLECTOR_CLASS_TERM $CONVERT_COLLECTOR_VARIABLE_TERM;
       |  ${reuseMemberCode()}
       |  $genMergeList
       |  public $funcName() throws Exception {
       |    ${reuseInitCode()}
       |    $CONVERT_COLLECTOR_VARIABLE_TERM = new $CONVERT_COLLECTOR_CLASS_TERM();
       |  }
       |  ${reuseConstructorCode(funcName)}
       |
       |  public final void open(
       |    org.apache.flink.api.common.functions.RuntimeContext $contextTerm) throws Exception {
       |    ${reuseOpenCode()}
       |  }
       |
       |  $aggFuncCode
       |
       |  public final void cleanup() throws Exception {
       |    ${reuseCleanupCode()}
       |  }
       |
       |  public final void close() throws Exception {
       |    ${reuseCloseCode()}
       |  }
       |
       |  private class $CONVERT_COLLECTOR_CLASS_TERM implements $baseCollectorString {
       |
       |    public $CROW_WRAPPING_COLLECTOR $COLLECTOR_VARIABLE_TERM;
       |    private final $ROW $CONVERTER_ROW_RESULT_TERM =
       |      new $ROW(${tableAggOutputType.getArity});
       |
       |    public $ROW convertToRow(Object record) throws Exception {
       |      $aggOutputTypeName in1 = ($aggOutputTypeName) record;
       |      $genRecordToRow
       |      return $CONVERTER_ROW_RESULT_TERM;
       |    }
       |
       |    @Override
       |    public void collect(Object record) throws Exception {
       |      $COLLECTOR_VARIABLE_TERM.collect(convertToRow(record));
       |    }
       |
       |    ${getRetractMethodForConvertCollector(finalEmitMethodName)}
       |
       |    @Override
       |    public void close() {
       |      $COLLECTOR_VARIABLE_TERM.close();
       |    }
       |  }
       |}
     """.stripMargin

  new GeneratedTableAggregationsFunction(funcName, funcCode, finalEmitMethodName != emitValue)
}