in flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala [723:1039]
/** Local references cannot be translated by this generator; always throws. */
override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = {
  val reason = "Local variables are not supported yet."
  throw new CodeGenException(reason)
}
/** Range references cannot be translated by this generator; always throws. */
override def visitRangeRef(rangeRef: RexRangeRef): GeneratedExpression = {
  val reason = "Range references are not supported yet."
  throw new CodeGenException(reason)
}
/** Dynamic parameters cannot be translated by this generator; always throws. */
override def visitDynamicParam(dynamicParam: RexDynamicParam): GeneratedExpression = {
  val reason = "Dynamic parameter references are not supported yet."
  throw new CodeGenException(reason)
}
/**
  * Generates code for a [[RexCall]], i.e. an operator or function applied to operands.
  *
  * A call to [[ProctimeSqlFunction]] is materialized via `generateProctimeTimestamp()`
  * without visiting its operands. For every other call, all operands are translated first
  * (an untyped NULL literal in first position is given the call's result type when the
  * operator infers its return type from its first argument), and the operator is then
  * dispatched to a dedicated generator. Operators without a dedicated case are resolved
  * through `FunctionGenerator.getCallGenerator`.
  *
  * @param call the call expression to translate
  * @return the generated expression computing the call's result
  * @throws CodeGenException if the operator is unsupported, or a DOT access targets a
  *                          non-composite type or an unknown field (the `require*` helpers
  *                          may additionally reject operand types)
  */
override def visitCall(call: RexCall): GeneratedExpression = {
  if (call.getOperator == ProctimeSqlFunction) {
    // special case: time materialization (handled before any operand is visited)
    generateProctimeTimestamp()
  } else {
    val resultType = FlinkTypeFactory.toTypeInfo(call.getType)

    // convert operands and help giving untyped NULL literals a type
    val operands = call.getOperands.zipWithIndex.map {
      // this helps e.g. for AS(null)
      // we might need to extend this logic in case some rules do not create typed NULLs
      case (operandLiteral: RexLiteral, 0) if
          operandLiteral.getType.getSqlTypeName == SqlTypeName.NULL &&
          call.getOperator.getReturnTypeInference == ReturnTypes.ARG0 =>
        generateNullLiteral(resultType)

      case (operand, _) =>
        operand.accept(this)
    }

    call.getOperator match {
      // arithmetic
      case PLUS if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left)
        requireNumeric(right)
        generateArithmeticOperator("+", nullCheck, resultType, left, right, config)

      case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireTemporal(left)
        requireTemporal(right)
        generateTemporalPlusMinus(plus = true, nullCheck, resultType, left, right, config)

      case MINUS if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left)
        requireNumeric(right)
        generateArithmeticOperator("-", nullCheck, resultType, left, right, config)

      case MINUS | MINUS_DATE if isTemporal(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireTemporal(left)
        requireTemporal(right)
        generateTemporalPlusMinus(plus = false, nullCheck, resultType, left, right, config)

      case MULTIPLY if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left)
        requireNumeric(right)
        generateArithmeticOperator("*", nullCheck, resultType, left, right, config)

      case MULTIPLY if isTimeInterval(resultType) =>
        // interval * number
        val left = operands.head
        val right = operands(1)
        requireTimeInterval(left)
        requireNumeric(right)
        generateArithmeticOperator("*", nullCheck, resultType, left, right, config)

      case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left)
        requireNumeric(right)
        generateArithmeticOperator("/", nullCheck, resultType, left, right, config)

      case MOD if isNumeric(resultType) =>
        val left = operands.head
        val right = operands(1)
        requireNumeric(left)
        requireNumeric(right)
        generateArithmeticOperator("%", nullCheck, resultType, left, right, config)

      case UNARY_MINUS if isNumeric(resultType) =>
        val operand = operands.head
        requireNumeric(operand)
        generateUnaryArithmeticOperator("-", nullCheck, resultType, operand)

      case UNARY_MINUS if isTimeInterval(resultType) =>
        val operand = operands.head
        requireTimeInterval(operand)
        generateUnaryIntervalPlusMinus(plus = false, nullCheck, operand)

      case UNARY_PLUS if isNumeric(resultType) =>
        val operand = operands.head
        requireNumeric(operand)
        generateUnaryArithmeticOperator("+", nullCheck, resultType, operand)

      case UNARY_PLUS if isTimeInterval(resultType) =>
        val operand = operands.head
        requireTimeInterval(operand)
        generateUnaryIntervalPlusMinus(plus = true, nullCheck, operand)

      // comparison
      case EQUALS =>
        val left = operands.head
        val right = operands(1)
        generateEquals(nullCheck, left, right)

      case IS_NOT_DISTINCT_FROM =>
        val left = operands.head
        val right = operands(1)
        generateIsNotDistinctFrom(nullCheck, left, right)

      case NOT_EQUALS =>
        val left = operands.head
        val right = operands(1)
        generateNotEquals(nullCheck, left, right)

      case GREATER_THAN =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left)
        requireComparable(right)
        generateComparison(">", nullCheck, left, right)

      case GREATER_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left)
        requireComparable(right)
        generateComparison(">=", nullCheck, left, right)

      case LESS_THAN =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left)
        requireComparable(right)
        generateComparison("<", nullCheck, left, right)

      case LESS_THAN_OR_EQUAL =>
        val left = operands.head
        val right = operands(1)
        requireComparable(left)
        requireComparable(right)
        generateComparison("<=", nullCheck, left, right)

      case IS_NULL =>
        val operand = operands.head
        generateIsNull(nullCheck, operand)

      case IS_NOT_NULL =>
        val operand = operands.head
        generateIsNotNull(nullCheck, operand)

      // logic
      case AND =>
        // n-ary AND is folded left into nested binary ANDs
        operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
          requireBoolean(left)
          requireBoolean(right)
          generateAnd(nullCheck, left, right)
        }

      case OR =>
        // n-ary OR is folded left into nested binary ORs
        operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
          requireBoolean(left)
          requireBoolean(right)
          generateOr(nullCheck, left, right)
        }

      case NOT =>
        val operand = operands.head
        requireBoolean(operand)
        generateNot(nullCheck, operand)

      case CASE =>
        generateIfElse(nullCheck, operands, resultType)

      case IS_TRUE =>
        val operand = operands.head
        requireBoolean(operand)
        generateIsTrue(operand)

      case IS_NOT_TRUE =>
        val operand = operands.head
        requireBoolean(operand)
        generateIsNotTrue(operand)

      case IS_FALSE =>
        val operand = operands.head
        requireBoolean(operand)
        generateIsFalse(operand)

      case IS_NOT_FALSE =>
        val operand = operands.head
        requireBoolean(operand)
        generateIsNotFalse(operand)

      case IN =>
        // first operand is the needle, the remaining operands form the set
        val left = operands.head
        val right = operands.tail
        generateIn(this, left, right)

      case NOT_IN =>
        val left = operands.head
        val right = operands.tail
        generateNot(nullCheck, generateIn(this, left, right))

      // casting
      case CAST | REINTERPRET =>
        val operand = operands.head
        generateCast(nullCheck, operand, resultType)

      // as / renaming: the alias does not change the value, forward the operand
      case AS =>
        operands.head

      // string arithmetic
      case CONCAT =>
        val left = operands.head
        val right = operands(1)
        // NOTE(review): only the left operand is checked here; presumably the right one
        // is guaranteed to be a string by validation — confirm before tightening.
        requireString(left)
        generateArithmeticOperator("+", nullCheck, resultType, left, right, config)

      // rows
      case ROW =>
        generateRow(this, resultType, operands)

      // arrays
      case ARRAY_VALUE_CONSTRUCTOR =>
        generateArray(this, resultType, operands)

      // maps
      case MAP_VALUE_CONSTRUCTOR =>
        generateMap(this, resultType, operands)

      case ITEM =>
        // element access: array[index] or map[key]
        operands.head.resultType match {
          case t: TypeInformation[_] if isArray(t) =>
            val array = operands.head
            val index = operands(1)
            requireInteger(index)
            generateArrayElementAt(this, array, index)

          case t: TypeInformation[_] if isMap(t) =>
            val key = operands(1)
            generateMapGet(this, operands.head, key)

          case _ => throw new CodeGenException("Expect an array or a map.")
        }

      case CARDINALITY =>
        operands.head.resultType match {
          case t: TypeInformation[_] if isArray(t) =>
            val array = operands.head
            generateArrayCardinality(nullCheck, array)

          case t: TypeInformation[_] if isMap(t) =>
            val map = operands.head
            generateMapCardinality(nullCheck, map)

          case _ => throw new CodeGenException("Expect an array or a map.")
        }

      case ELEMENT =>
        val array = operands.head
        requireArray(array)
        generateArrayElement(this, array)

      case DOT =>
        // Due to https://issues.apache.org/jira/browse/CALCITE-2162, expression such as
        // "array[1].a.b" won't work now.
        if (operands.size > 2) {
          throw new CodeGenException(
            "A DOT operator with more than 2 operands is not supported yet.")
        }
        val composite = operands.head
        // the second operand carries the accessed field's name as a literal
        val fieldName = call.getOperands.get(1)
          .asInstanceOf[RexLiteral]
          .getValueAs(classOf[String])
        val fieldIdx = composite.resultType match {
          case compositeType: CompositeType[_] =>
            compositeType.getFieldIndex(fieldName)
          case other =>
            throw new CodeGenException(
              s"A DOT operator expects a composite type but got: $other")
        }
        // CompositeType.getFieldIndex signals an unknown field with -1; fail here with a
        // clear message instead of generating an access to an invalid index
        if (fieldIdx < 0) {
          throw new CodeGenException(
            s"Field '$fieldName' does not exist in type: ${composite.resultType}")
        }
        generateFieldAccess(composite, fieldIdx)

      case ScalarSqlFunctions.CONCAT =>
        generateConcat(this.nullCheck, operands)

      case ScalarSqlFunctions.CONCAT_WS =>
        generateConcatWs(operands)

      case StreamRecordTimestampSqlFunction =>
        generateStreamRecordRowtimeAccess()

      // advanced scalar functions
      case sqlOperator: SqlOperator =>
        val callGen = FunctionGenerator.getCallGenerator(
          sqlOperator,
          operands.map(_.resultType),
          resultType)
        callGen
          .getOrElse(throw new CodeGenException(s"Unsupported call: $sqlOperator \n" +
            s"If you think this function should be supported, " +
            s"you can create an issue and start a discussion for it."))
          .generate(this, operands)

      // unknown or invalid; presumably only reachable for a null operator, since any
      // SqlOperator is matched by the case above
      case unsupportedOperator =>
        throw new CodeGenException(s"Unsupported call: $unsupportedOperator")
    }
  }
}