rules_jvm_export/jvm_export/jvm_assembly.bzl (246 lines of code) (raw):

VersionInfo = provider( doc="A singleton provider that contains the raw value of a build setting", fields={"value": "Version number"}, ) def _jvm_assembly_impl(ctx): # version = ctx.attr.version[VersionInfo].value class_jar = generate_class_jar(ctx, None) source_jar = _generate_source_jar(ctx) output_files = [class_jar] return [ DefaultInfo( files=depset(output_files), ), JavaInfo( output_jar=class_jar, compile_jar=class_jar, source_jar=source_jar, ), ] def parse_maven_coordinates(coordinates_string, enforce_version_template=True): """ Given a string containing a standard Maven coordinate (g:a:[p:[c:]]v), returns a Maven artifact map (see above). See also https://github.com/bazelbuild/rules_jvm_external/blob/4.3/specs.bzl """ parts = coordinates_string.split(":") group_id, artifact_id = parts[0:2] if len(parts) == 3: version = parts[2] result = struct(group_id=group_id, artifact_id=artifact_id, version=version) elif len(parts) == 4: packaging = parts[2] version = parts[3] result = struct(group_id=group_id, artifact_id=artifact_id, packaging=packaging, version=version) elif len(parts) == 5: packaging = parts[2] classifier = parts[3] version = parts[4] result = struct(group_id=group_id, artifact_id=artifact_id, packaging=packaging, classifier=classifier, version=version) else: fail("failed to parse '{}'".format(coordinates_string)) if enforce_version_template and version != "{pom_version}": fail("should assign {pom_version} as Maven version via `tags` attribute") return result def jar_assembler(ctx): script = ctx.actions.declare_file( "jvm-export/{}-jar-assembler.py".format(ctx.attr.name) ) ctx.actions.expand_template( template=ctx.file._jar_assembler, output=script, substitutions={ "$PYTHON_PATH": ctx.attr.python_path, }, ) return script def runtime_output_jar(target): if JavaInfo in target: if len(target[JavaInfo].runtime_output_jars) == 1: return target[JavaInfo].runtime_output_jars[0] elif len(target[JavaInfo].runtime_output_jars) == 2: for jar in target[JavaInfo].runtime_output_jars: if jar.path.endswith("_java.jar"): return jar fail( "expected size 1, or the file name ends with _java.jar, but runtime_output_jars in {} was {}".format( target, target[JavaInfo].runtime_output_jars ) ) else: fail( "expected size 1, but runtime_output_jars in {} was {}".format( target, target[JavaInfo].runtime_output_jars ) ) else: outputs = target[DefaultInfo].files.to_list() return outputs[0] def generate_class_jar(ctx, pom_file): target = ctx.attr.target maven_coordinates = parse_maven_coordinates(target[JarInfo].name) jar = runtime_output_jar(target) output_jar = ctx.actions.declare_file( "{}:{}.jar".format(maven_coordinates.group_id, maven_coordinates.artifact_id) ) class_jar_deps = [ dep.class_jar for dep in target[JarInfo].jar_infos.to_list() if dep.type == "jar" ] class_jar_paths = [jar.path] + [target.path for target in class_jar_deps] args = ( [ "--group_id=" + maven_coordinates.group_id, "--artifact_id=" + maven_coordinates.artifact_id, ] + ([pom_file] if pom_file else []) + [ "--output=" + output_jar.path, ] + class_jar_paths ) inputs = [jar] + ([pom_file] if pom_file else []) + class_jar_deps ctx.actions.run( executable=jar_assembler(ctx), inputs=inputs, outputs=[output_jar], arguments=args, ) return output_jar def _generate_source_jar(ctx): target = ctx.attr.target maven_coordinates = parse_maven_coordinates(target[JarInfo].name) srcjar = None if len(target[JavaInfo].source_jars) < 1: fail("Could not find source JAR to deploy in {}".format(target)) else: srcjar = target[JavaInfo].source_jars[0] output_jar = ctx.actions.declare_file( "{}:{}-sources.jar".format( maven_coordinates.group_id, maven_coordinates.artifact_id ) ) source_jar_deps = [ dep.source_jar for dep in target[JarInfo].jar_infos.to_list() if dep.type == "jar" and dep.source_jar ] source_jar_paths = [srcjar.path] + [target.path for target in source_jar_deps] ctx.actions.run( executable=jar_assembler(ctx), inputs=[srcjar] + source_jar_deps, outputs=[output_jar], arguments=[ "--output=" + output_jar.path, ] + source_jar_paths, ) return output_jar def find_maven_coordinates(target, tags): _TAG_KEY_MAVEN_COORDINATES = "maven_coordinates=" _TAG_KEY_JVM_MODULE = "jvm_module=" _TAG_KEY_JVM_VERSION = "jvm_version=" mod = None ver = None for tag in tags: if tag.startswith(_TAG_KEY_MAVEN_COORDINATES): coordinates = tag[len(_TAG_KEY_MAVEN_COORDINATES) :] return coordinates elif tag.startswith(_TAG_KEY_JVM_MODULE): mod = tag[len(_TAG_KEY_JVM_MODULE) :] elif tag.startswith(_TAG_KEY_JVM_VERSION): ver = tag[len(_TAG_KEY_JVM_VERSION) :] if mod and ver: return "{}:{}".format(mod, ver) JarInfo = provider( fields={ "name": "The name of a the JAR (Maven coordinates)", "deps": "Direct dependencies", "jar_infos": "The list of dependencies of this JAR. A dependency may be of two types, POM or JAR.", "neverlink": "Forward neverlink from target", }, ) def _aggregate_dependency_info_impl(target, ctx): tags = getattr(ctx.rule.attr, "tags", []) deps = getattr(ctx.rule.attr, "deps", []) runtime_deps = getattr(ctx.rule.attr, "runtime_deps", []) exports = getattr(ctx.rule.attr, "exports", []) deps_all = deps + exports + runtime_deps neverlink = getattr(ctx.rule.attr, "neverlink", False) maven_coordinates = find_maven_coordinates(target, tags) dependencies = [] # depend via POM if maven_coordinates: dependencies = [struct(type="pom", maven_coordinates=maven_coordinates)] # include runtime output jars elif (JavaInfo in target) and target[JavaInfo].runtime_output_jars: jars = target[JavaInfo].runtime_output_jars source_jars = target[JavaInfo].source_jars dependencies = [ struct( type="jar", class_jar=jar, source_jar=source_jar, ) for (jar, source_jar) in zip( jars, source_jars + [None] * (len(jars) - len(source_jars)) ) ] return JarInfo( name=maven_coordinates, deps=deps, jar_infos=depset( dependencies, transitive=[ # Filter transitive JARs from dependency that has maven coordinates # because those dependencies will already include the JARs as part # of their classpath depset( [ dep for dep in jar[JarInfo].jar_infos.to_list() if dep.type == "pom" ] ) if jar[JarInfo].name else jar[JarInfo].jar_infos for jar in deps_all ], ), neverlink=neverlink, ) aggregate_dependency_info = aspect( attr_aspects=[ "jars", "deps", "exports", "runtime_deps", ], doc="Collects the Maven coordinates of the given java_library, its direct dependencies, and its transitive dependencies", implementation=_aggregate_dependency_info_impl, provides=[JarInfo], ) jvm_assembly = rule( attrs={ "target": attr.label( mandatory=True, aspects=[ aggregate_dependency_info, ], doc="Java target for subsequent deployment", ), "python_path": attr.string( default="/usr/bin/env python", doc="Path to python command", ), "_jar_assembler": attr.label( default="@twitter_rules_jvm_export//jvm_export/support:jar_assembler.py", executable=True, allow_single_file=True, cfg="host", ), }, implementation=_jvm_assembly_impl, doc="Aggregated JVM target", )