in flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala [501:825]
/**
 * Rejects windowed aggregate calls: this visitor has no code generation for [[RexOver]]
 * nodes and always fails.
 *
 * @throws CodeGenException unconditionally
 */
override def visitOver(over: RexOver): GeneratedExpression = {
  val reason = "Aggregate functions over windows are not supported yet."
  throw new CodeGenException(reason)
}
/**
 * Rejects sub-query nodes: this visitor has no code generation for [[RexSubQuery]]
 * and always fails.
 *
 * @throws CodeGenException unconditionally
 */
override def visitSubQuery(subQuery: RexSubQuery): GeneratedExpression = {
  val reason = "Subqueries are not supported yet."
  throw new CodeGenException(reason)
}
/**
 * Rejects pattern field references: this visitor has no code generation for
 * [[RexPatternFieldRef]] nodes and always fails.
 *
 * @throws CodeGenException unconditionally
 */
override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression = {
  val reason = "Pattern field references are not supported yet."
  throw new CodeGenException(reason)
}
// ----------------------------------------------------------------------------------------
/**
 * Generates the expression for a call by dispatching on the call's operator.
 *
 * The cases are tried from top to bottom. Several operators appear more than once with
 * different guards on `resultType` (e.g. PLUS for numeric vs. temporal results), so the
 * order of the cases is significant. Operators without a dedicated case fall through to
 * the function-based generators at the end of the match.
 *
 * @param ctx        code generator context handed to the individual generators
 * @param call       the call being translated; its operator selects the branch
 * @param operands   the generated operand expressions, accessed positionally
 * @param resultType the logical type the generated expression has to produce
 * @return the generated expression for the call
 * @throws CodeGenException if the operand type of ITEM / CARDINALITY is not handled here,
 *                          or if no generator is found for the operator
 */
private def generateCallExpression(
    ctx: CodeGeneratorContext,
    call: RexCall,
    operands: Seq[GeneratedExpression],
    resultType: LogicalType): GeneratedExpression = {

  // The helpers below factor out operand handling shared by several operators.
  // Each one reads the operands first, then runs the require* checks, then generates,
  // which is the same order the individual cases used before.

  // Binary arithmetic where both operands have to be numeric.
  def binaryNumericCall(operator: String): GeneratedExpression = {
    val left = operands.head
    val right = operands(1)
    requireNumeric(left)
    requireNumeric(right)
    generateBinaryArithmeticOperator(ctx, operator, resultType, left, right)
  }

  // Plus / minus where both operands have to be temporal.
  def temporalPlusMinusCall(isPlus: Boolean): GeneratedExpression = {
    val left = operands.head
    val right = operands(1)
    requireTemporal(left)
    requireTemporal(right)
    generateTemporalPlusMinus(ctx, plus = isPlus, resultType, left, right)
  }

  // Unary sign operator on a numeric operand.
  def unaryNumericCall(operator: String): GeneratedExpression = {
    val operand = operands.head
    requireNumeric(operand)
    generateUnaryArithmeticOperator(ctx, operator, resultType, operand)
  }

  // Unary sign operator on a time interval operand.
  def unaryIntervalCall(isPlus: Boolean): GeneratedExpression = {
    val operand = operands.head
    requireTimeInterval(operand)
    generateUnaryIntervalPlusMinus(ctx, plus = isPlus, operand)
  }

  // Ordering comparison where both operands have to be comparable.
  def comparisonCall(operator: String): GeneratedExpression = {
    val left = operands.head
    val right = operands(1)
    requireComparable(left)
    requireComparable(right)
    generateComparison(ctx, operator, left, right)
  }

  // Returns the first operand after checking it with requireBoolean.
  def checkedBooleanOperand: GeneratedExpression = {
    val operand = operands.head
    requireBoolean(operand)
    operand
  }

  call.getOperator match {

    // arithmetic
    case PLUS if isNumeric(resultType) =>
      binaryNumericCall("+")

    case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
      temporalPlusMinusCall(isPlus = true)

    case MINUS if isNumeric(resultType) =>
      binaryNumericCall("-")

    case MINUS | MINUS_DATE if isTemporal(resultType) =>
      temporalPlusMinusCall(isPlus = false)

    case MULTIPLY if isNumeric(resultType) =>
      binaryNumericCall("*")

    // interval * numeric: the left side is the interval, the right side the factor
    case MULTIPLY if isTimeInterval(resultType) =>
      val left = operands.head
      val right = operands(1)
      requireTimeInterval(left)
      requireNumeric(right)
      generateBinaryArithmeticOperator(ctx, "*", resultType, left, right)

    case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
      binaryNumericCall("/")

    case MOD if isNumeric(resultType) =>
      binaryNumericCall("%")

    case UNARY_MINUS if isNumeric(resultType) =>
      unaryNumericCall("-")

    case UNARY_MINUS if isTimeInterval(resultType) =>
      unaryIntervalCall(isPlus = false)

    case UNARY_PLUS if isNumeric(resultType) =>
      unaryNumericCall("+")

    case UNARY_PLUS if isTimeInterval(resultType) =>
      unaryIntervalCall(isPlus = true)

    // comparison
    case EQUALS =>
      generateEquals(ctx, operands.head, operands(1))

    case IS_NOT_DISTINCT_FROM =>
      generateIsNotDistinctFrom(ctx, operands.head, operands(1))

    case NOT_EQUALS =>
      generateNotEquals(ctx, operands.head, operands(1))

    case GREATER_THAN =>
      comparisonCall(">")

    case GREATER_THAN_OR_EQUAL =>
      comparisonCall(">=")

    case LESS_THAN =>
      comparisonCall("<")

    case LESS_THAN_OR_EQUAL =>
      comparisonCall("<=")

    case IS_NULL =>
      generateIsNull(ctx, operands.head)

    case IS_NOT_NULL =>
      generateIsNotNull(ctx, operands.head)

    // logic
    // AND / OR may carry more than two operands; they are folded pairwise from the left.
    case AND =>
      operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
        requireBoolean(left)
        requireBoolean(right)
        generateAnd(ctx, left, right)
      }

    case OR =>
      operands.reduceLeft { (left: GeneratedExpression, right: GeneratedExpression) =>
        requireBoolean(left)
        requireBoolean(right)
        generateOr(ctx, left, right)
      }

    case NOT =>
      generateNot(ctx, checkedBooleanOperand)

    case CASE =>
      generateIfElse(ctx, operands, resultType)

    case IS_TRUE =>
      generateIsTrue(checkedBooleanOperand)

    case IS_NOT_TRUE =>
      generateIsNotTrue(checkedBooleanOperand)

    case IS_FALSE =>
      generateIsFalse(checkedBooleanOperand)

    case IS_NOT_FALSE =>
      generateIsNotFalse(checkedBooleanOperand)

    // the first operand is the needle, the remaining operands form the haystack
    case IN =>
      generateIn(ctx, operands.head, operands.tail)

    case NOT_IN =>
      generateNot(ctx, generateIn(ctx, operands.head, operands.tail))

    // casting
    case CAST =>
      generateCast(ctx, operands.head, resultType)

    // Reinterpret
    case REINTERPRET =>
      generateReinterpret(ctx, operands.head, resultType)

    // as / renaming
    case AS =>
      operands.head

    // rows
    case ROW =>
      generateRow(ctx, resultType, operands)

    // arrays
    case ARRAY_VALUE_CONSTRUCTOR =>
      generateArray(ctx, resultType, operands)

    // maps
    case MAP_VALUE_CONSTRUCTOR =>
      generateMap(ctx, resultType, operands)

    // ITEM is resolved by the type root of the accessed value
    case ITEM =>
      operands.head.resultType.getTypeRoot match {
        case LogicalTypeRoot.ARRAY =>
          val array = operands.head
          val index = operands(1)
          requireInteger(index)
          generateArrayElementAt(ctx, array, index)

        case LogicalTypeRoot.MAP =>
          val key = operands(1)
          generateMapGet(ctx, operands.head, key)

        case LogicalTypeRoot.ROW | LogicalTypeRoot.STRUCTURED_TYPE =>
          generateDot(ctx, operands)

        // rows and structured types are accepted above, so the message names them too
        case _ => throw new CodeGenException("Expect an array, a map or a row.")
      }

    case CARDINALITY =>
      operands.head.resultType match {
        case t: LogicalType if TypeCheckUtils.isArray(t) =>
          generateArrayCardinality(ctx, operands.head)

        case t: LogicalType if TypeCheckUtils.isMap(t) =>
          generateMapCardinality(ctx, operands.head)

        case _ => throw new CodeGenException("Expect an array or a map.")
      }

    case ELEMENT =>
      val array = operands.head
      requireArray(array)
      generateArrayElement(ctx, array)

    case DOT =>
      generateDot(ctx, operands)

    case PROCTIME =>
      // attribute is proctime indicator.
      // We use a null literal and generate a timestamp when we need it.
      generateNullLiteral(
        new TimestampType(true, TimestampKind.PROCTIME, 3),
        ctx.nullCheck)

    case PROCTIME_MATERIALIZE =>
      generateProctimeTimestamp(ctx, contextTerm)

    case STREAMRECORD_TIMESTAMP =>
      generateRowtimeAccess(ctx, contextTerm)

    case _: SqlThrowExceptionFunction =>
      // The generated code declares a null result of the requested type so the expression
      // still has result/null terms, then rethrows using the first operand's string form.
      val nullValue = generateNullLiteral(resultType, nullCheck = true)
      val code =
        s"""
           |${operands.map(_.code).mkString("\n")}
           |${nullValue.code}
           |org.apache.flink.util.ExceptionUtils.rethrow(
           | new RuntimeException(${operands.head.resultTerm}.toString()));
           |""".stripMargin
      GeneratedExpression(nullValue.resultTerm, nullValue.nullTerm, code, resultType)

    case ssf: ScalarSqlFunction =>
      new ScalarFunctionCallGen(
        ssf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
        .generate(ctx, operands, resultType)

    case tsf: TableSqlFunction =>
      new TableFunctionCallGen(
        call,
        tsf.makeFunction(getOperandLiterals(operands), operands.map(_.resultType).toArray))
        .generate(ctx, operands, resultType)

    case _: BridgingSqlFunction =>
      new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)

    // advanced scalar functions
    // StringCallGen is consulted first; FunctionGenerator is the fallback lookup.
    case sqlOperator: SqlOperator =>
      StringCallGen.generateCallExpression(ctx, sqlOperator, operands, resultType)
        .getOrElse {
          FunctionGenerator
            .getCallGenerator(
              sqlOperator,
              operands.map(expr => expr.resultType),
              resultType)
            .getOrElse(
              throw new CodeGenException(s"Unsupported call: " +
                s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
                s"If you think this function should be supported, " +
                s"you can create an issue and start a discussion for it."))
            .generate(ctx, operands, resultType)
        }

    // unknown or invalid
    // NOTE(review): if getOperator is statically a SqlOperator, only a null operator can
    // reach this case - confirm against RexCall.
    case unsupported =>
      val explainCall = s"$unsupported(${operands.map(_.resultType).mkString(", ")})"
      throw new CodeGenException(s"Unsupported call: $explainCall")
  }
}