in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PushPartitionIntoLegacyTableSourceScanRule.scala [70:221]
override def onMatch(call: RelOptRuleCall): Unit = {
  val filter: Filter = call.rel(0)
  val scan: LogicalTableScan = call.rel(1)
  val context = call.getPlanner.getContext.unwrap(classOf[FlinkContext])
  val config = context.getTableConfig
  val tableSourceTable = scan.getTable.unwrap(classOf[LegacyTableSourceTable[_]])
  val tableIdentifier = tableSourceTable.tableIdentifier
  val catalogOption = toScala(context.getCatalogManager.getCatalog(
    tableIdentifier.getCatalogName))
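
  // partition keys of the catalog table and the field names visible to the filter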
  val partitionFieldNames = tableSourceTable.catalogTable.getPartitionKeys.toSeq.toArray[String]
  val tableSource = tableSourceTable.tableSource.asInstanceOf[PartitionableTableSource]
  val inputFieldType = filter.getInput.getRowType
  val inputFields = inputFieldType.getFieldNames.toList.toArray
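
  // split the filter condition into predicates that reference only partition
  // fields and the remaining predicates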
  val relBuilder = call.builder()
  val rexBuilder = relBuilder.getRexBuilder
  val maxCnfNodeCount = FlinkRelOptUtil.getMaxCnfNodeCount(scan)
  val (partitionPredicates, nonPartitionPredicates) =
    RexNodeExtractor.extractPartitionPredicateList(
      filter.getCondition,
      maxCnfNodeCount,
      inputFields,
      rexBuilder,
      partitionFieldNames
    )
  val partitionPredicate = RexUtil.composeConjunction(rexBuilder, partitionPredicates)
  if (partitionPredicate.isAlwaysTrue) {
    // no partition predicates in filter
    return
  }
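
  // resolve the logical type of every partition field from the filter's input row type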
  val partitionFieldTypes = partitionFieldNames.map { name =>
    val index = inputFieldType.getFieldNames.indexOf(name)
    require(index >= 0, s"$name is not found in ${inputFieldType.getFieldNames.mkString(", ")}")
    inputFieldType.getFieldList.get(index).getType
  }.map(FlinkTypeFactory.toLogicalType)
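
  // a PartitionableTableSource may not report its partitions; remember
  // whether it does, so we can fall back to the catalog below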
  val partitionsFromSource = try {
    Some(tableSource.getPartitions)
  } catch {
    case _: UnsupportedOperationException => None
  }
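
  // all partition specs, taken from the source if available, otherwise listed from the catalog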
  def getAllPartitions: util.List[util.Map[String, String]] = {
    partitionsFromSource match {
      case Some(parts) => parts
      case None => catalogOption match {
        case Some(catalog) =>
          catalog.listPartitions(tableIdentifier.toObjectPath).map(_.getPartitionSpec).toList
        case None => throw new TableException(
          s"Failed to get partitions of table $tableIdentifier: the table source " +
            s"$tableSource does not provide partitions and the table is not from a catalog.")
      }
    }
  }
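
  // planner-side pruning: evaluate the partition predicate against every partition spec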
  def internalPartitionPrune(): util.List[util.Map[String, String]] = {
    val allPartitions = getAllPartitions
    val finalPartitionPredicate = adjustPartitionPredicate(
      inputFieldType.getFieldNames.toList.toArray,
      partitionFieldNames,
      partitionPredicate
    )
    PartitionPruner.prunePartitions(
      config,
      partitionFieldNames,
      partitionFieldTypes,
      allPartitions,
      finalPartitionPredicate
    )
  }
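
  // prefer catalog-side filtering when the source does not report its partitions;
  // fall back to planner-side pruning if the predicates cannot be converted to
  // Expressions or the catalog does not support filtering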
  val remainingPartitions: util.List[util.Map[String, String]] = partitionsFromSource match {
    case Some(_) => internalPartitionPrune()
    case None =>
      catalogOption match {
        case Some(catalog) =>
          val converter = new RexNodeToExpressionConverter(
            inputFields,
            context.getFunctionCatalog,
            context.getCatalogManager,
            TimeZone.getTimeZone(config.getLocalTimeZone))

          // convert the partition predicates to Expressions;
          // give up if any single predicate cannot be converted
          def toExpressions: Option[Seq[Expression]] = {
            val expressions = new mutable.ArrayBuffer[Expression]()
            for (predicate <- partitionPredicates) {
              predicate.accept(converter) match {
                case Some(expr) => expressions += expr
                case None => return None
              }
            }
            Some(expressions)
          }

          toExpressions match {
            case Some(expressions) =>
              try {
                catalog
                  .listPartitionsByFilter(tableIdentifier.toObjectPath, expressions)
                  .map(_.getPartitionSpec)
              } catch {
                case _: UnsupportedOperationException => internalPartitionPrune()
              }
            case None => internalPartitionPrune()
          }
        case None => internalPartitionPrune()
      }
  }
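
  // hand the pruned partition list to the source, which must return a new source
  // instance whose explainSource() reflects the applied pruning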
  val newTableSource = tableSource.applyPartitionPruning(remainingPartitions)
  if (newTableSource.explainSource().equals(tableSourceTable.tableSource.explainSource())) {
    throw new TableException("Failed to push partitions into table source! "
      + "A table source with push-down capability must override explainSource() "
      + "so that the applied push-down is reflected in the plan.")
  }
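
  // recompute the table statistics by merging the statistics of the remaining partitions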
  val statistic = tableSourceTable.getStatistic
  val newStatistic = {
    val tableStats = catalogOption match {
      case Some(catalog) =>
        // if any remaining partition has no statistics, the merged result is unknown
        def mergePartitionStats(): TableStats = {
          var stats: TableStats = null
          for (p <- remainingPartitions) {
            getPartitionStats(catalog, tableIdentifier, p) match {
              case Some(currStats) =>
                if (stats == null) {
                  stats = currStats
                } else {
                  stats = stats.merge(currStats)
                }
              case None => return null
            }
          }
          stats
        }
        mergePartitionStats()
      case None => null
    }
    FlinkStatistic.builder().statistic(statistic).tableStats(tableStats).build()
  }
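
  // rebuild the scan over the pruned table source with the new statistics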
  val newTableSourceTable = tableSourceTable.copy(newTableSource, newStatistic)
  val newScan = new LogicalTableScan(scan.getCluster, scan.getTraitSet, newTableSourceTable)

  // check whether the framework still needs to apply a filter
  // for the non-partition predicates
  val nonPartitionPredicate = RexUtil.composeConjunction(rexBuilder, nonPartitionPredicates)
  if (nonPartitionPredicate.isAlwaysTrue) {
    call.transformTo(newScan)
  } else {
    val newFilter = filter.copy(filter.getTraitSet, newScan, nonPartitionPredicate)
    call.transformTo(newFilter)
  }
}