in scio-avro/src/main/scala/com/spotify/scio/avro/types/TypeProvider.scala [198:326]
/**
 * Macro implementation that expands an annotated, field-less class into a case class whose
 * fields mirror the fields of the Avro record `schema`.
 *
 * The expansion consists of the case class itself, a companion (built by `companion`) that is
 * given a `schema` method parsing `schema.toString` and, when `getRecordDocs` yields a doc, a
 * `doc` method, followed by one case class definition per nested record.
 *
 * Compilation is aborted when the annottee is not a single class without modifier flags, when
 * that class has parents other than `scala.AnyRef` or declares constructor fields, or when the
 * schema contains a type that cannot be mapped (including unions that are not exactly
 * `null` plus one other schema).
 *
 * @param c macro context
 * @param schema Avro record schema the fields are derived from
 * @param annottees trees the macro annotation was applied to
 * @return the expanded definitions wrapped in a `c.Expr`
 */
private def schemaToType(
  c: blackbox.Context
)(schema: Schema, annottees: Seq[c.Expr[Any]]): c.Expr[Any] = {
  import c.universe._
  checkMacroEnclosed(c)

  // Maps an Avro field schema to a pair of:
  //   - the Scala type tree for the field (e.g. Int, String, Option[...], a nested record type)
  //   - the case class definitions generated for any records nested inside that field
  def getField(className: String, fieldName: String, fieldSchema: Schema): (Tree, Seq[Tree]) =
    fieldSchema.getType match {
      case UNION =>
        // Inspect the member schemas themselves rather than their distinct type tags: Avro
        // permits several named types (e.g. two records) in one union, and those share a tag.
        // Deduplicating by tag would let ["null", A, B] through and silently ignore B.
        fieldSchema.getTypes.asScala.toList.partition(_.getType == NULL) match {
          case (List(_), List(nonNullSchema)) =>
            val (field, recordClasses) = getField(className, fieldName, nonNullSchema)
            (tq"_root_.scala.Option[$field]", recordClasses)
          case _ =>
            c.abort(
              c.enclosingPosition,
              s"type: ${fieldSchema.getType} is not supported. " +
                s"Union type needs to contain exactly one 'null' type and one non null type."
            )
        }
      case BOOLEAN =>
        (tq"_root_.scala.Boolean", Nil)
      case LONG =>
        (tq"_root_.scala.Long", Nil)
      case DOUBLE =>
        (tq"_root_.scala.Double", Nil)
      case INT =>
        (tq"_root_.scala.Int", Nil)
      case FLOAT =>
        (tq"_root_.scala.Float", Nil)
      case STRING | ENUM =>
        // Enums are surfaced as plain strings.
        (tq"_root_.java.lang.String", Nil)
      case BYTES =>
        (tq"_root_.com.google.protobuf.ByteString", Nil)
      case ARRAY =>
        val (field, generatedCaseClasses) =
          getField(className, fieldName, fieldSchema.getElementType)
        (tq"_root_.scala.List[$field]", generatedCaseClasses)
      case MAP =>
        // Map keys are always typed as String; only the value type is derived from the schema.
        val (fieldType, recordCaseClasses) =
          getField(className, fieldName, fieldSchema.getValueType)
        (tq"_root_.scala.collection.Map[_root_.java.lang.String,$fieldType]", recordCaseClasses)
      case RECORD =>
        // Nested records become their own case class named "<enclosing class>$<record name>".
        val nestedClassName = s"$className$$${fieldSchema.getName}"
        val (fields, recordClasses) =
          extractFields(nestedClassName, fieldSchema)
        (
          q"${Ident(TypeName(nestedClassName))}",
          Seq(q"case class ${TypeName(nestedClassName)}(..$fields)") ++ recordClasses
        )
      case t =>
        c.abort(c.enclosingPosition, s"type: $t not supported")
    }

  // Builds the constructor parameter tree for one record field.
  // Returns: ("fieldName: fieldType", nested case class definitions)
  def extractField(
    className: String,
    fieldName: String,
    fieldSchema: Schema
  ): (Tree, Seq[Tree]) = {
    val (fieldType, recordClasses) =
      getField(className, SchemaUtil.unescapeNameIfReserved(fieldName), fieldSchema)
    fieldSchema.getType match {
      case UNION =>
        // Union fields are typed as Option[...] and therefore default to None.
        (q"val ${TermName(fieldName)}: $fieldType = None", recordClasses)
      case _ =>
        (q"${TermName(fieldName)}: $fieldType", recordClasses)
    }
  }

  // Builds the parameter trees for every field of `schema`.
  // Returns: (one parameter tree per field, de-duplicated nested case class definitions)
  def extractFields(className: String, schema: Schema): (Seq[Tree], Seq[Tree]) = {
    val extracted = schema.getFields.asScala
      .map(f => extractField(className, f.name, f.schema))
    val fields = extracted.map(_._1)
    val recordClasses = extracted
      .flatMap(_._2)
      .groupBy(_.asInstanceOf[ClassDef].name)
      // note that if there are conflicting definitions of a nested record type, the Avro schema
      // parser itself will catch it before getting to this step.
      .map { case (_, cDefs) => cDefs.head } // Don't generate duplicate case classes
    (fields.toSeq, recordClasses.toSeq)
  }

  val r = annottees.map(_.tree) match {
    // Only a single class definition without any modifier flags is accepted.
    case l @ List(
          q"$mods class $name[..$_] $_(..$cfields) extends { ..$_ } with ..$parents { $_ => ..$_ }"
        ) if mods.asInstanceOf[Modifiers].flags == NoFlags =>
      if (parents.map(_.toString()).toSet != Set("scala.AnyRef")) {
        c.abort(c.enclosingPosition, s"Invalid annotation, don't extend the case class $l")
      }
      if (cfields.nonEmpty) {
        c.abort(c.enclosingPosition, s"Invalid annotation, don't provide class fields $l")
      }

      val (fields, recordClasses) = extractFields(name.toString, schema)

      // The `doc` override and the HasAvroDoc parent are only emitted when a doc is available.
      val docs = getRecordDocs(c)(l)
      val docMethod = docs.headOption
        .map(d => q"override def doc: _root_.java.lang.String = $d")
        .toSeq
      val docTrait = docMethod
        .map(_ => tq"${p(c, ScioAvroType)}.HasAvroDoc")

      // The schema is embedded as its string form and parsed again by the generated code.
      val schemaMethod = Seq(q"""override def schema: ${p(c, ApacheAvro)}.Schema =
        new ${p(c, ApacheAvro)}.Schema.Parser().parse(${schema.toString})""")

      val caseClassTree = q"${caseClass(c)(mods, name, fields, Nil)}"
      if (shouldDumpClassesForPlugin) {
        dumpCodeForScalaPlugin(c)(recordClasses, caseClassTree, name.toString())
      }

      q"""$caseClassTree
          ${companion(c)(name, docTrait, schemaMethod ++ docMethod, fields)}
          ..$recordClasses
        """
    case t => c.abort(c.enclosingPosition, s"Invalid annotation $t")
  }
  c.Expr[Any](r)
}