def getColumnInterval()

in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala [389:676]


  /**
    * Gets interval of the given column on Aggregate.
    *
    * @param aggregate Aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on Aggregate
    */
  def getColumnInterval(aggregate: Aggregate, mq: RelMetadataQuery, index: Int): ValueInterval =
    estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
    * Gets interval of the given column on TableAggregate.
    *
    * @param aggregate TableAggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on TableAggregate
    */
  def getColumnInterval(
      aggregate: TableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
    * Gets interval of the given column on batch group aggregate.
    *
    * @param aggregate batch group aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on batch group aggregate
    */
  def getColumnInterval(
      aggregate: BatchExecGroupAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
    * Gets interval of the given column on stream group aggregate.
    *
    * @param aggregate stream group aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on stream group Aggregate
    */
  def getColumnInterval(
      aggregate: StreamExecGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
    * Gets interval of the given column on stream group table aggregate.
    *
    * @param aggregate stream group table aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on stream group TableAggregate
    */
  def getColumnInterval(
      aggregate: StreamExecGroupTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(aggregate, mq, index)

  /**
    * Gets interval of the given column on stream local group aggregate.
    *
    * @param aggregate stream local group aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on stream local group Aggregate
    */
  def getColumnInterval(
      aggregate: StreamExecLocalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
    * Gets interval of the given column on stream global group aggregate.
    *
    * @param aggregate stream global group aggregate RelNode
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column
    * @return interval of the given column on stream global group Aggregate
    */
  def getColumnInterval(
      aggregate: StreamExecGlobalGroupAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(aggregate, mq, index)
  }

  /**
    * Gets interval of the given column on window aggregate.
    *
    * @param agg   window aggregate RelNode
    * @param mq    RelMetadataQuery instance
    * @param index the index of the given column
    * @return interval of the given column on window Aggregate
    */
  def getColumnInterval(
      agg: WindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
    * Gets interval of the given column on batch window aggregate.
    *
    * @param agg   batch window aggregate RelNode
    * @param mq    RelMetadataQuery instance
    * @param index the index of the given column
    * @return interval of the given column on batch window Aggregate
    */
  def getColumnInterval(
      agg: BatchExecWindowAggregateBase,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
    * Gets interval of the given column on stream window aggregate.
    *
    * @param agg   stream window aggregate RelNode
    * @param mq    RelMetadataQuery instance
    * @param index the index of the given column
    * @return interval of the given column on stream window Aggregate
    */
  def getColumnInterval(
      agg: StreamExecGroupWindowAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    estimateColumnIntervalOfAggregate(agg, mq, index)
  }

  /**
    * Gets interval of the given column on stream window table aggregate.
    *
    * @param agg   stream window table aggregate RelNode
    * @param mq    RelMetadataQuery instance
    * @param index the index of the given column
    * @return interval of the given column on stream window TableAggregate
    */
  def getColumnInterval(
      agg: StreamExecGroupWindowTableAggregate,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = estimateColumnIntervalOfAggregate(agg, mq, index)

  /**
    * Estimates the value interval of the given output column for all supported aggregate-like
    * RelNodes (logical, batch and stream variants of group / window / table aggregates).
    *
    * <p>The output row of an aggregate consists of the grouping keys followed by the aggregate
    * function results. Group-key columns simply forward the interval of the corresponding input
    * column; for aggregate-call columns only SUM/SUM0 and COUNT are estimated so far.
    *
    * @param aggregate the aggregate-like RelNode (must be one of the matched types below,
    *                  otherwise a MatchError is raised)
    * @param mq        RelMetadataQuery instance
    * @param index     the index of the given column in the aggregate's output row type
    * @return interval of the given column, or null if it cannot be estimated
    */
  private def estimateColumnIntervalOfAggregate(
      aggregate: SingleRel,
      mq: RelMetadataQuery,
      index: Int): ValueInterval = {
    val input = aggregate.getInput
    val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
    // The full set of grouping keys (in output order) for each supported aggregate variant.
    // Note: for local window aggregates the assigned window timestamp field sits between
    // grouping and auxGrouping in the output row.
    val groupSet = aggregate match {
      case agg: StreamExecGroupAggregate => agg.grouping
      case agg: StreamExecLocalGroupAggregate => agg.grouping
      case agg: StreamExecGlobalGroupAggregate => agg.grouping
      case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping
      case agg: StreamExecGroupWindowAggregate => agg.getGrouping
      case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
      case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
      case agg: BatchExecLocalSortWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
      case agg: BatchExecLocalHashWindowAggregate =>
        // grouping + assignTs + auxGrouping
        agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
      case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
      case agg: TableAggregate => agg.getGroupSet.toArray
      case agg: StreamExecGroupTableAggregate => agg.grouping
      case agg: StreamExecGroupWindowTableAggregate => agg.getGrouping
    }

    if (index < groupSet.length) {
      // estimates group keys according to the input relNodes.
      val sourceFieldIndex = groupSet(index)
      fmq.getColumnInterval(input, sourceFieldIndex)
    } else {
      // Finds the local aggregate call that produces output column `index`, or null
      // if that column does not map to an aggregate call (e.g. it holds a count accumulator).
      def getAggCallFromLocalAgg(
          index: Int,
          aggCalls: Seq[AggregateCall],
          inputType: RelDataType): AggregateCall = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          aggCalls, inputType)
        if (outputIndexToAggCallIndexMap.containsKey(index)) {
          val realIndex = outputIndexToAggCallIndexMap.get(index)
          aggCalls(realIndex)
        } else {
          null
        }
      }

      // Maps a global aggregate-call index back to the corresponding output index of the
      // local (pre-merge) aggregate, or null if no such mapping exists.
      def getAggCallIndexInLocalAgg(
          index: Int,
          globalAggCalls: Seq[AggregateCall],
          inputRowType: RelDataType): Integer = {
        val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
          globalAggCalls, inputRowType)

        outputIndexToAggCallIndexMap.foreach {
          case (k, v) => if (v == index) {
            return k
          }
        }
        null.asInstanceOf[Integer]
      }

      // Column belongs to an aggregate call: resolve which AggregateCall produced it.
      // For merging (global/final) aggregates the interval is delegated to the matching
      // column of the local aggregate via an early return.
      val aggCallIndex = index - groupSet.length
      val aggCall = aggregate match {
        case agg: StreamExecGroupAggregate if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: StreamExecGlobalGroupAggregate
          if agg.globalAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex, agg.globalAggInfoList.getActualAggregateCalls, agg.inputRowType)
          if (aggCallIndexInLocalAgg != null) {
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: StreamExecLocalGroupAggregate =>
          getAggCallFromLocalAgg(
            aggCallIndex, agg.aggInfoList.getActualAggregateCalls, agg.getInput.getRowType)
        case agg: StreamExecIncrementalGroupAggregate
          if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
          agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
        case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
          agg.aggCalls(aggCallIndex)
        case agg: BatchExecLocalHashAggregate =>
          getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
        case agg: BatchExecHashAggregate if agg.isMerge =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
          if (aggCallIndexInLocalAgg != null) {
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: BatchExecLocalSortAggregate =>
          getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
        case agg: BatchExecSortAggregate if agg.isMerge =>
          val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
            aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
          if (aggCallIndexInLocalAgg != null) {
            return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
          } else {
            null
          }
        case agg: BatchExecGroupAggregateBase if agg.getAggCallList.length > aggCallIndex =>
          agg.getAggCallList(aggCallIndex)
        case agg: Aggregate =>
          val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
          if (aggCalls.length > aggCallIndex) {
            aggCalls(aggCallIndex)
          } else {
            null
          }
        case agg: BatchExecWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
          agg.getAggCallList(aggCallIndex)
        case _ => null
      }

      if (aggCall != null) {
        aggCall.getAggregation.getKind match {
          case SUM | SUM0 =>
            // SUM over a non-negative input is bounded below by its lower bound;
            // SUM over a non-positive input is bounded above by its upper bound.
            // A sign-mixed input gives no usable bound.
            val inputInterval = fmq.getColumnInterval(input, aggCall.getArgList.get(0))
            if (inputInterval != null) {
              inputInterval match {
                case withLower: WithLower if withLower.lower.isInstanceOf[Number] =>
                  if (withLower.lower.asInstanceOf[Number].doubleValue() >= 0.0) {
                    RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
                  } else {
                    null.asInstanceOf[ValueInterval]
                  }
                case withUpper: WithUpper if withUpper.upper.isInstanceOf[Number] =>
                  if (withUpper.upper.asInstanceOf[Number].doubleValue() <= 0.0) {
                    LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
                  } else {
                    null
                  }
                case _ => null
              }
            } else {
              null
            }
          case COUNT =>
            // COUNT is always non-negative.
            RightSemiInfiniteValueInterval(JBigDecimal.valueOf(0), includeLower = true)
          // TODO add more built-in agg functions
          case _ => null
        }
      } else {
        null
      }
    }
  }