private[this] def createEventTimeSessionWindowDataSet()

in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala [236:409]

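Translates an event-time session window aggregate into a DataSet program: a prepare map phase, followed by either an incremental (combine + group reduce) or a non-incremental (group reduce only) aggregation, each with a grouped and a non-grouped variant.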

  private[this] def createEventTimeSessionWindowDataSet(
      config: TableConfig,
      nullableInput: Boolean,
      inputTypeInfo: TypeInformation[_ <: Any],
      constants: Option[Seq[RexLiteral]],
      inputDS: DataSet[Row],
      isParserCaseSensitive: Boolean,
      tableConfig: TableConfig): DataSet[Row] = {

    val input = inputNode.asInstanceOf[DataSetRel]

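    // key positions in the prepared intermediate rows; the prepare mapper
    // emits the grouping fields first, so indices 0 until grouping.length apply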
    val groupingKeys = grouping.indices.toArray
    val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)

    // create mapFunction for initializing the aggregations
    val mapFunction = createDataSetWindowPrepareMapFunction(
      config,
      nullableInput,
      inputTypeInfo,
      constants,
      window,
      namedAggregates,
      grouping,
      input.getRowType,
      inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
      isParserCaseSensitive,
      tableConfig)

    val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)

    val mapReturnType = mapFunction.asInstanceOf[ResultTypeQueryable[Row]].getProducedType

    // the position of the rowtime field in the intermediate result (the last field of the map output)
    val rowTimeFieldPos = mapReturnType.getArity - 1

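    // If all aggregate functions support partial merging, sessions can be
    // pre-aggregated in a combine phase and the partial windows merged in the
    // final reduce; otherwise every row of a session must reach the reducer.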
    // do incremental aggregation
    if (doAllSupportPartialMerge(
      namedAggregates.map(_.getKey),
      inputType,
      grouping.length,
      tableConfig)) {

      // get the window-start and window-end positions in the intermediate result
      val windowStartPos = rowTimeFieldPos
      val windowEndPos = windowStartPos + 1
      // grouping window
      if (groupingKeys.length > 0) {
        // create groupCombineFunction for combining the aggregations
        val combineGroupFunction = createDataSetWindowAggregationCombineFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          grouping,
          tableConfig)

        // create groupReduceFunction for calculating the aggregations
        val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          rowRelDataType,
          grouping,
          namedProperties,
          tableConfig,
          isInputCombined = true)

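        // plan: sort each group by rowtime so the combiner sees the rows of a
        // session contiguously and can pre-aggregate per session gap, then
        // re-group and sort the partial windows by start and end time so the
        // reducer can merge sessions that were split across combiners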
        mappedInput
          .groupBy(groupingKeys: _*)
          .sortGroup(rowTimeFieldPos, Order.ASCENDING)
          .combineGroup(combineGroupFunction)
          .groupBy(groupingKeys: _*)
          .sortGroup(windowStartPos, Order.ASCENDING)
          .sortGroup(windowEndPos, Order.ASCENDING)
          .reduceGroup(groupReduceFunction)
          .returns(rowTypeInfo)
          .name(aggregateOperatorName)
      } else {
        // non-grouping window
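        // create mapPartitionFunction for pre-aggregating the session windows
        // within each partition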
        val mapPartitionFunction = createDataSetWindowAggregationMapPartitionFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          grouping,
          tableConfig)

        // create groupReduceFunction for calculating the aggregations
        val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          rowRelDataType,
          grouping,
          namedProperties,
          tableConfig,
          isInputCombined = true)

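        // plan: pre-aggregate sessions within each partition in rowtime order,
        // then merge all partial windows in a single parallelism-1 task sorted
        // by window start and end time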
        mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING)
          .mapPartition(mapPartitionFunction)
          .sortPartition(windowStartPos, Order.ASCENDING).setParallelism(1)
          .sortPartition(windowEndPos, Order.ASCENDING).setParallelism(1)
          .reduceGroup(groupReduceFunction)
          .returns(rowTypeInfo)
          .name(aggregateOperatorName)
          .asInstanceOf[DataSet[Row]]
      }
    // do non-incremental aggregation
    } else {
      // grouping window
      if (groupingKeys.length > 0) {

        // create groupReduceFunction for calculating the aggregations
        val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          rowRelDataType,
          grouping,
          namedProperties,
          tableConfig)

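        // plan: sort each group by rowtime so the reducer sees the rows of a
        // session contiguously and can assign windows and aggregate in one pass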
        mappedInput.groupBy(groupingKeys: _*)
          .sortGroup(rowTimeFieldPos, Order.ASCENDING)
          .reduceGroup(groupReduceFunction)
          .returns(rowTypeInfo)
          .name(aggregateOperatorName)
      } else {
        // non-grouping window
        val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
          config,
          nullableInput,
          inputTypeInfo,
          constants,
          window,
          namedAggregates,
          input.getRowType,
          inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
          rowRelDataType,
          grouping,
          namedProperties,
          tableConfig)

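        // plan: without grouping keys, all rows are sorted by rowtime in a
        // single parallelism-1 task and aggregated in one group reduce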
        mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING).setParallelism(1)
          .reduceGroup(groupReduceFunction)
          .returns(rowTypeInfo)
          .name(aggregateOperatorName)
          .asInstanceOf[DataSet[Row]]
      }
    }
  }
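
For orientation, a minimal sketch of a batch Table API query that would be translated through this method. It assumes the legacy (pre-Blink) flink-table-planner Scala API; the input fields and the 10-minute gap are illustrative, not taken from the source above.

import org.apache.flink.api.scala._
import org.apache.flink.table.api._
import org.apache.flink.table.api.scala._

// hypothetical input: (user, amount, ts) tuples with a timestamp column
val env = ExecutionEnvironment.getExecutionEnvironment
val tEnv = BatchTableEnvironment.create(env)

val orders = env
  .fromElements(("alice", 10, java.sql.Timestamp.valueOf("2019-01-01 00:00:00")))
  .toTable(tEnv, 'user, 'amount, 'ts)

val result = orders
  .window(Session withGap 10.minutes on 'ts as 'w)  // event-time session window
  .groupBy('w, 'user)                               // grouped variant of this method
  .select('user, 'amount.sum, 'w.start, 'w.end)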