in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala [236:409]
/**
  * Assembles the DataSet program for an event-time session window aggregation.
  *
  * The input is first run through a prepare map function. The rowtime field is the last
  * field of the rows produced by that function. One of four plans is then built:
  *
  *  - partial merge supported, with grouping keys:
  *    group + sort on rowtime, combine, group again + sort on window start/end, reduce
  *  - partial merge supported, without grouping keys:
  *    sort partitions on rowtime, mapPartition, sort on window start/end with
  *    parallelism 1, reduce
  *  - partial merge not supported, with grouping keys:
  *    group + sort on rowtime, reduce
  *  - partial merge not supported, without grouping keys:
  *    sort on rowtime with parallelism 1, reduce
  *
  * @param config                handed to every function factory invoked here
  * @param nullableInput         handed to every function factory invoked here
  * @param inputTypeInfo         handed to every function factory invoked here
  * @param constants             handed to every function factory invoked here
  * @param inputDS               the DataSet to aggregate; its type must be a [[RowTypeInfo]]
  * @param isParserCaseSensitive forwarded to the prepare map function factory only
  * @param tableConfig           handed to the function factories and to the
  *                              partial-merge support check
  * @return the DataSet produced by the final group reduce, typed with the internal row
  *         type derived from `getRowType`
  */
private[this] def createEventTimeSessionWindowDataSet(
    config: TableConfig,
    nullableInput: Boolean,
    inputTypeInfo: TypeInformation[_ <: Any],
    constants: Option[Seq[RexLiteral]],
    inputDS: DataSet[Row],
    isParserCaseSensitive: Boolean,
    tableConfig: TableConfig): DataSet[Row] = {

  val inputRel = inputNode.asInstanceOf[DataSetRel]
  val keyPositions = grouping.indices.toArray
  val resultTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
  val isGrouped = keyPositions.length > 0

  // Field types of the incoming DataSet. Kept as a def so that it is looked up again for
  // every factory call instead of being cached once.
  def inputFieldTypes = inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes

  // Builds the final group reduce function. The isInputCombined argument is only passed
  // when the rows arriving at the reducer have already been pre-aggregated; otherwise the
  // factory is called without it.
  def finalReduceFunction(inputIsCombined: Boolean) =
    if (inputIsCombined) {
      createDataSetWindowAggregationGroupReduceFunction(
        config,
        nullableInput,
        inputTypeInfo,
        constants,
        window,
        namedAggregates,
        inputRel.getRowType,
        inputFieldTypes,
        rowRelDataType,
        grouping,
        namedProperties,
        tableConfig,
        isInputCombined = true)
    } else {
      createDataSetWindowAggregationGroupReduceFunction(
        config,
        nullableInput,
        inputTypeInfo,
        constants,
        window,
        namedAggregates,
        inputRel.getRowType,
        inputFieldTypes,
        rowRelDataType,
        grouping,
        namedProperties,
        tableConfig)
    }

  // map function that initializes the aggregations
  val prepareFunction = createDataSetWindowPrepareMapFunction(
    config,
    nullableInput,
    inputTypeInfo,
    constants,
    window,
    namedAggregates,
    grouping,
    inputRel.getRowType,
    inputFieldTypes,
    isParserCaseSensitive,
    tableConfig)

  val prepared = inputDS.map(prepareFunction).name(prepareOperatorName)

  // the rowtime field is the last field of the prepare function's output rows
  val preparedType = prepareFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType
  val rowTimeFieldPos = preparedType.getArity - 1

  val incremental = doAllSupportPartialMerge(
    namedAggregates.map(_.getKey),
    inputType,
    grouping.length,
    tableConfig)

  // Positions of window start and window end in the pre-aggregated intermediate rows.
  // Only the two incremental plans below read them.
  val windowStartPos = rowTimeFieldPos
  val windowEndPos = windowStartPos + 1

  if (incremental && isGrouped) {
    // incremental aggregation, grouping window
    val combineFunction = createDataSetWindowAggregationCombineFunction(
      config,
      nullableInput,
      inputTypeInfo,
      constants,
      window,
      namedAggregates,
      inputRel.getRowType,
      inputFieldTypes,
      grouping,
      tableConfig)
    val reduceFunction = finalReduceFunction(inputIsCombined = true)

    val combined = prepared
      .groupBy(keyPositions: _*)
      .sortGroup(rowTimeFieldPos, Order.ASCENDING)
      .combineGroup(combineFunction)

    combined
      .groupBy(keyPositions: _*)
      .sortGroup(windowStartPos, Order.ASCENDING)
      .sortGroup(windowEndPos, Order.ASCENDING)
      .reduceGroup(reduceFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
  } else if (incremental) {
    // incremental aggregation, non-grouping window
    val partitionFunction = createDataSetWindowAggregationMapPartitionFunction(
      config,
      nullableInput,
      inputTypeInfo,
      constants,
      window,
      namedAggregates,
      inputRel.getRowType,
      inputFieldTypes,
      grouping,
      tableConfig)
    val reduceFunction = finalReduceFunction(inputIsCombined = true)

    val preAggregated = prepared
      .sortPartition(rowTimeFieldPos, Order.ASCENDING)
      .mapPartition(partitionFunction)

    // both sorts on the pre-aggregated rows are set to parallelism 1
    val sortedByWindow = preAggregated
      .sortPartition(windowStartPos, Order.ASCENDING).setParallelism(1)
      .sortPartition(windowEndPos, Order.ASCENDING).setParallelism(1)

    sortedByWindow
      .reduceGroup(reduceFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
      .asInstanceOf[DataSet[Row]]
  } else if (isGrouped) {
    // non-incremental aggregation, grouping window
    val reduceFunction = finalReduceFunction(inputIsCombined = false)

    prepared
      .groupBy(keyPositions: _*)
      .sortGroup(rowTimeFieldPos, Order.ASCENDING)
      .reduceGroup(reduceFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
  } else {
    // non-incremental aggregation, non-grouping window
    val reduceFunction = finalReduceFunction(inputIsCombined = false)

    // the rowtime sort is set to parallelism 1 before the reduce
    val sortedByRowTime =
      prepared.sortPartition(rowTimeFieldPos, Order.ASCENDING).setParallelism(1)

    sortedByRowTime
      .reduceGroup(reduceFunction)
      .returns(resultTypeInfo)
      .name(aggregateOperatorName)
      .asInstanceOf[DataSet[Row]]
  }
}