in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnInterval.scala [389:676]
/**
  * Gets interval of the given column on Aggregate.
  *
  * @param aggregate Aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on Aggregate
  */
def getColumnInterval(
    aggregate: Aggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on TableAggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate TableAggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on TableAggregate
  */
def getColumnInterval(
    aggregate: TableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on batch group aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate batch group aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on batch group aggregate
  */
def getColumnInterval(
    aggregate: BatchExecGroupAggregateBase,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on stream group aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate stream group aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream group Aggregate
  */
def getColumnInterval(
    aggregate: StreamExecGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on stream group table aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate stream group table aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream group TableAggregate
  */
def getColumnInterval(
    aggregate: StreamExecGroupTableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on stream local group aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate stream local group aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream local group Aggregate
  */
def getColumnInterval(
    aggregate: StreamExecLocalGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on stream global group aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param aggregate stream global group aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream global group Aggregate
  */
def getColumnInterval(
    aggregate: StreamExecGlobalGroupAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(aggregate, mq, index)
}
/**
  * Gets interval of the given column on window aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param agg window aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on window Aggregate
  */
def getColumnInterval(
    agg: WindowAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(agg, mq, index)
}
/**
  * Gets interval of the given column on batch window aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param agg batch window aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on batch window Aggregate
  */
def getColumnInterval(
    agg: BatchExecWindowAggregateBase,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(agg, mq, index)
}
/**
  * Gets interval of the given column on stream window aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param agg stream window aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream window Aggregate
  */
def getColumnInterval(
    agg: StreamExecGroupWindowAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(agg, mq, index)
}
/**
  * Gets interval of the given column on stream window table aggregate.
  *
  * Delegates to the shared aggregate estimation logic.
  *
  * @param agg stream window table aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the given column
  * @return interval of the given column on stream window Aggregate
  */
def getColumnInterval(
    agg: StreamExecGroupWindowTableAggregate,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  estimateColumnIntervalOfAggregate(agg, mq, index)
}
/**
  * Estimates the value interval of the given output column of an aggregate-like
  * RelNode (logical, batch-exec and stream-exec variants).
  *
  * Group-key columns are forwarded to the interval of the corresponding input
  * column. For aggregate-function columns only SUM / SUM0 / COUNT are currently
  * estimated; every other function yields null ("unknown").
  *
  * FIX: removed an unreachable duplicate of the `index < groupSet.length` check
  * that previously sat inside the else-branch of the identical condition —
  * neither `index` nor `groupSet` changes between the two checks, so its
  * then-arm could never execute.
  *
  * @param aggregate the aggregate RelNode
  * @param mq RelMetadataQuery instance
  * @param index the index of the output column
  * @return interval of the given column, or null if it cannot be derived
  */
private def estimateColumnIntervalOfAggregate(
    aggregate: SingleRel,
    mq: RelMetadataQuery,
    index: Int): ValueInterval = {
  val input = aggregate.getInput
  val fmq = FlinkRelMetadataQuery.reuseOrCreate(mq)
  // Full group set of the concrete aggregate variant: grouping keys plus
  // auxiliary grouping and, for local window aggregates, the assigned
  // timestamp field. The match is intentionally non-exhaustive: an
  // unsupported RelNode type fails fast with a MatchError.
  val groupSet = aggregate match {
    case agg: StreamExecGroupAggregate => agg.grouping
    case agg: StreamExecLocalGroupAggregate => agg.grouping
    case agg: StreamExecGlobalGroupAggregate => agg.grouping
    case agg: StreamExecIncrementalGroupAggregate => agg.partialAggGrouping
    case agg: StreamExecGroupWindowAggregate => agg.getGrouping
    case agg: BatchExecGroupAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
    case agg: Aggregate => AggregateUtil.checkAndGetFullGroupSet(agg)
    case agg: BatchExecLocalSortWindowAggregate =>
      // grouping + assignTs + auxGrouping
      agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
    case agg: BatchExecLocalHashWindowAggregate =>
      // grouping + assignTs + auxGrouping
      agg.getGrouping ++ Array(agg.inputTimeFieldIndex) ++ agg.getAuxGrouping
    case agg: BatchExecWindowAggregateBase => agg.getGrouping ++ agg.getAuxGrouping
    case agg: TableAggregate => agg.getGroupSet.toArray
    case agg: StreamExecGroupTableAggregate => agg.grouping
    case agg: StreamExecGroupWindowTableAggregate => agg.getGrouping
  }
  if (index < groupSet.length) {
    // Group key: delegate to the interval of the source input column.
    val sourceFieldIndex = groupSet(index)
    fmq.getColumnInterval(input, sourceFieldIndex)
  } else {
    // Resolves the AggregateCall that produced output column `index` of a
    // local aggregate; null when the column is not an aggregate result.
    def getAggCallFromLocalAgg(
        index: Int,
        aggCalls: Seq[AggregateCall],
        inputType: RelDataType): AggregateCall = {
      val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
        aggCalls, inputType)
      if (outputIndexToAggCallIndexMap.containsKey(index)) {
        val realIndex = outputIndexToAggCallIndexMap.get(index)
        aggCalls(realIndex)
      } else {
        null
      }
    }

    // Maps a global agg-call index back to the local aggregate's output
    // index; null when no mapping exists.
    def getAggCallIndexInLocalAgg(
        index: Int,
        globalAggCalls: Seq[AggregateCall],
        inputRowType: RelDataType): Integer = {
      val outputIndexToAggCallIndexMap = AggregateUtil.getOutputIndexToAggCallIndexMap(
        globalAggCalls, inputRowType)
      outputIndexToAggCallIndexMap.foreach {
        case (k, v) => if (v == index) {
          return k
        }
      }
      null.asInstanceOf[Integer]
    }

    val aggCallIndex = index - groupSet.length
    // Locate the AggregateCall backing this output column. Merging/global
    // variants may instead delegate directly to the local aggregate's
    // column interval (the early `return`s below).
    val aggCall = aggregate match {
      case agg: StreamExecGroupAggregate if agg.aggCalls.length > aggCallIndex =>
        agg.aggCalls(aggCallIndex)
      case agg: StreamExecGlobalGroupAggregate
        if agg.globalAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
        val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
          aggCallIndex, agg.globalAggInfoList.getActualAggregateCalls, agg.inputRowType)
        if (aggCallIndexInLocalAgg != null) {
          return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
        } else {
          null
        }
      case agg: StreamExecLocalGroupAggregate =>
        getAggCallFromLocalAgg(
          aggCallIndex, agg.aggInfoList.getActualAggregateCalls, agg.getInput.getRowType)
      case agg: StreamExecIncrementalGroupAggregate
        if agg.partialAggInfoList.getActualAggregateCalls.length > aggCallIndex =>
        agg.partialAggInfoList.getActualAggregateCalls(aggCallIndex)
      case agg: StreamExecGroupWindowAggregate if agg.aggCalls.length > aggCallIndex =>
        agg.aggCalls(aggCallIndex)
      case agg: BatchExecLocalHashAggregate =>
        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
      case agg: BatchExecHashAggregate if agg.isMerge =>
        val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
        if (aggCallIndexInLocalAgg != null) {
          return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
        } else {
          null
        }
      case agg: BatchExecLocalSortAggregate =>
        getAggCallFromLocalAgg(aggCallIndex, agg.getAggCallList, agg.getInput.getRowType)
      case agg: BatchExecSortAggregate if agg.isMerge =>
        val aggCallIndexInLocalAgg = getAggCallIndexInLocalAgg(
          aggCallIndex, agg.getAggCallList, agg.aggInputRowType)
        if (aggCallIndexInLocalAgg != null) {
          return fmq.getColumnInterval(agg.getInput, groupSet.length + aggCallIndexInLocalAgg)
        } else {
          null
        }
      case agg: BatchExecGroupAggregateBase if agg.getAggCallList.length > aggCallIndex =>
        agg.getAggCallList(aggCallIndex)
      case agg: Aggregate =>
        val (_, aggCalls) = AggregateUtil.checkAndSplitAggCalls(agg)
        if (aggCalls.length > aggCallIndex) {
          aggCalls(aggCallIndex)
        } else {
          null
        }
      case agg: BatchExecWindowAggregateBase if agg.getAggCallList.length > aggCallIndex =>
        agg.getAggCallList(aggCallIndex)
      case _ => null
    }
    if (aggCall != null) {
      aggCall.getAggregation.getKind match {
        case SUM | SUM0 =>
          // SUM over a non-negative input is bounded below by the input's
          // lower bound (symmetrically, non-positive input bounds SUM above).
          // A mixed-sign input gives no usable bound.
          val inputInterval = fmq.getColumnInterval(input, aggCall.getArgList.get(0))
          if (inputInterval != null) {
            inputInterval match {
              case withLower: WithLower if withLower.lower.isInstanceOf[Number] =>
                if (withLower.lower.asInstanceOf[Number].doubleValue() >= 0.0) {
                  RightSemiInfiniteValueInterval(withLower.lower, withLower.includeLower)
                } else {
                  null.asInstanceOf[ValueInterval]
                }
              case withUpper: WithUpper if withUpper.upper.isInstanceOf[Number] =>
                if (withUpper.upper.asInstanceOf[Number].doubleValue() <= 0.0) {
                  LeftSemiInfiniteValueInterval(withUpper.upper, withUpper.includeUpper)
                } else {
                  null
                }
              case _ => null
            }
          } else {
            null
          }
        case COUNT =>
          // COUNT is always non-negative.
          RightSemiInfiniteValueInterval(JBigDecimal.valueOf(0), includeLower = true)
        // TODO add more built-in agg functions
        case _ => null
      }
    } else {
      null
    }
  }
}