in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala [66:226]
def genWithKeys(): GeneratedOperator[OneInputStreamOperator[RowData, RowData]] = {
  /**
   * Generates a hash-based aggregate operator for the case with grouping keys.
   *
   * The generated operator is named `HashAggregateWithKeys` when `isFinal` is true and
   * `LocalHashAggregateWithKeys` otherwise. Its base class is [[TableStreamOperator]].
   *
   * Per input record (processCode) the generated code does the following:
   *  1. Projects the group key out of the input row.
   *  2. Looks the key up in a `BytesHashMap`.
   *  3. If the key is absent, appends an initial agg buffer. An `EOFException` from `append`
   *     is handled by the code from `HashAggCodeGenHelper.genAggMapOOMHandling`.
   *  4. Runs the aggregate update against the current agg buffer.
   *
   * At end of input (endInputCode):
   *  - Local aggregation always emits the contents of the aggregate map.
   *  - Final aggregation emits the map directly if no sorter was created. Otherwise it
   *    spills the map's remaining records into the sorter, frees the map, and falls back
   *    to sort-based aggregation.
   *
   * NOTE(review): keep the order of the `ctx`-mutating calls below unchanged. For example,
   * `ctx.reuseInputUnboxingCode` is evaluated while `processCode` is built, after all
   * helpers above have run. It presumably depends on state those helpers register in `ctx`.
   *
   * @return the generated one-input operator
   */
  val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
  val className = if (isFinal) "HashAggregateWithKeys" else "LocalHashAggregateWithKeys"

  // add logger
  val logTerm = CodeGenUtils.newName("LOG")
  ctx.addReusableLogger(logTerm, className)

  // gen code to do group key projection from input
  val currentKeyTerm = CodeGenUtils.newName("currentKey")
  val currentKeyWriterTerm = CodeGenUtils.newName("currentKeyWriter")
  val keyProjectionCode = ProjectionCodeGenerator.generateProjectionExpression(
    ctx,
    inputType,
    groupKeyRowType,
    grouping,
    inputTerm = inputTerm,
    outRecordTerm = currentKeyTerm,
    outRecordWriterTerm = currentKeyWriterTerm).code

  // gen code to create groupKey, aggBuffer Type array
  // it will be used in BytesHashMap and BufferedKVExternalSorter if enable fallback
  val groupKeyTypesTerm = CodeGenUtils.newName("groupKeyTypes")
  val aggBufferTypesTerm = CodeGenUtils.newName("aggBufferTypes")
  HashAggCodeGenHelper.prepareHashAggKVTypes(
    ctx, groupKeyTypesTerm, aggBufferTypesTerm, groupKeyRowType, aggBufferRowType)

  // gen code to aggregate and output using hash map
  val aggregateMapTerm = CodeGenUtils.newName("aggregateMap")
  // getCanonicalName (not getName) so the nested class is spelled `BytesHashMap.LookupInfo`
  // in the generated Java source rather than the binary name `BytesHashMap$LookupInfo`.
  val lookupInfo = ctx.addReusableLocalVariable(
    classOf[BytesHashMap.LookupInfo].getCanonicalName,
    "lookupInfo")
  HashAggCodeGenHelper.prepareHashAggMap(
    ctx,
    groupKeyTypesTerm,
    aggBufferTypesTerm,
    aggregateMapTerm)

  val outputTerm = CodeGenUtils.newName("hashAggOutput")
  val (reuseAggMapEntryTerm, reuseGroupKeyTerm, reuseAggBufferTerm) =
    HashAggCodeGenHelper.prepareTermForAggMapIteration(
      ctx,
      outputTerm,
      outputType,
      groupKeyRowType,
      aggBufferRowType,
      if (grouping.isEmpty) classOf[GenericRowData] else classOf[JoinedRowData])

  val currentAggBufferTerm = ctx.addReusableLocalVariable(
    classOf[BinaryRowData].getName, "currentAggBuffer")
  val (initedAggBuffer, aggregate, outputExpr) = HashAggCodeGenHelper.genHashAggCodes(
    isMerge,
    isFinal,
    ctx,
    builder,
    (grouping, auxGrouping),
    inputTerm,
    inputType,
    aggInfos,
    currentAggBufferTerm,
    aggBufferRowType,
    aggBufferTypes,
    outputTerm,
    outputType,
    reuseGroupKeyTerm,
    reuseAggBufferTerm)

  // code that iterates the aggregate map and emits one output row per entry
  val outputResultFromMap = HashAggCodeGenHelper.genAggMapIterationAndOutput(
    ctx, isFinal, aggregateMapTerm, reuseAggMapEntryTerm, reuseAggBufferTerm, outputExpr)

  // gen code to deal with hash map oom, if enable fallback we will use sort agg strategy
  val sorterTerm = CodeGenUtils.newName("sorter")
  val retryAppend = HashAggCodeGenHelper.genRetryAppendToMap(
    aggregateMapTerm, currentKeyTerm, initedAggBuffer, lookupInfo, currentAggBufferTerm)
  val (dealWithAggHashMapOOM, fallbackToSortAggCode) = HashAggCodeGenHelper.genAggMapOOMHandling(
    isFinal,
    ctx,
    builder,
    (grouping, auxGrouping),
    aggInfos,
    functionIdentifiers,
    logTerm,
    aggregateMapTerm,
    (groupKeyTypesTerm, aggBufferTypesTerm),
    (groupKeyRowType, aggBufferRowType),
    aggBufferNames,
    aggBufferTypes,
    outputTerm,
    outputType,
    outputResultFromMap,
    sorterTerm,
    retryAppend)

  // Only the final aggregation passes a sorter term. The local aggregation passes null,
  // presumably because it never creates a spill sorter (confirm in prepareMetrics).
  HashAggCodeGenHelper.prepareMetrics(ctx, aggregateMapTerm, if (isFinal) sorterTerm else null)

  // With auxiliary grouping fields the initial agg buffer depends on the current input row,
  // so its init code is emitted inline right before the append.
  // NOTE(review): inferred; confirm in genHashAggCodes.
  val lazyInitAggBufferCode = if (auxGrouping.nonEmpty) {
    s"""
       |// lazy init agg buffer (with auxGrouping)
       |${initedAggBuffer.code}
       """.stripMargin
  } else {
    ""
  }

  val processCode =
    s"""
       | // input field access for group key projection and aggregate buffer update
       |${ctx.reuseInputUnboxingCode(inputTerm)}
       | // project key from input
       |$keyProjectionCode
       | // look up output buffer using current group key
       |$lookupInfo = $aggregateMapTerm.lookup($currentKeyTerm);
       |$currentAggBufferTerm = $lookupInfo.getValue();
       |
       |if (!$lookupInfo.isFound()) {
       | $lazyInitAggBufferCode
       | // append empty agg buffer into aggregate map for current group key
       | try {
       | $currentAggBufferTerm =
       | $aggregateMapTerm.append($lookupInfo, ${initedAggBuffer.resultTerm});
       | } catch (java.io.EOFException exp) {
       | $dealWithAggHashMapOOM
       | }
       |}
       | // aggregate buffer fields access
       |${ctx.reuseInputUnboxingCode(currentAggBufferTerm)}
       | // do aggregate and update agg buffer
       |${aggregate.code}
       |""".stripMargin.trim

  val endInputCode = if (isFinal) {
    val memPoolTypeTerm = classOf[BytesHashMapSpillMemorySegmentPool].getName
    s"""
       |if ($sorterTerm == null) {
       | // no spilling, output by iterating aggregate map.
       | $outputResultFromMap
       |} else {
       | // spill last part of input' aggregation output buffer
       | $sorterTerm.sortAndSpill(
       | $aggregateMapTerm.getRecordAreaMemorySegments(),
       | $aggregateMapTerm.getNumElements(),
       | new $memPoolTypeTerm($aggregateMapTerm.getBucketAreaMemorySegments()));
       | // only release floating memory in advance.
       | $aggregateMapTerm.free(true);
       | // fall back to sort based aggregation
       | $fallbackToSortAggCode
       |}
       """.stripMargin
  } else {
    // Local aggregation never spills: always emit straight from the map. The value is
    // already a String, so no interpolation wrapper is needed.
    outputResultFromMap
  }

  AggCodeGenHelper.generateOperator(
    ctx,
    className,
    classOf[TableStreamOperator[RowData]].getCanonicalName,
    processCode,
    endInputCode,
    inputType)
}