public Node apply()

in runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/graph/RegisterNodeFunction.java [160:467]


  /**
   * Translates a network of Dataflow instructions into a Fn API register request.
   *
   * <p>The network may only contain {@link RemoteGrpcPortNode}s, {@link ParallelInstructionNode}s
   * and {@link InstructionOutputNode}s. The translation proceeds in these steps:
   *
   * <ol>
   *   <li>Every {@link DefaultEdge} leaving a node that is not an {@link InstructionOutputNode} is
   *       replaced, in {@code input} itself, by a {@link MultiOutputInfoEdge} carrying a freshly
   *       generated tag, so that every PTransform output is named.
   *   <li>Each {@link InstructionOutputNode} becomes a PCollection with its own generated coder id
   *       and a shared placeholder windowing strategy.
   *   <li>Each {@link ParallelInstructionNode} becomes a PTransform (ParDo, combine, read or
   *       flatten) whose inputs/outputs are the PCollections of its neighbouring output nodes.
   *   <li>Each {@link RemoteGrpcPortNode} becomes a data-input PTransform (if it only has
   *       successors) or a data-output PTransform (if it only has predecessors).
   * </ol>
   *
   * @param input the network to translate; its edges are rewritten in place (step 1)
   * @return a {@link RegisterRequestNode} wrapping the register request together with the name
   *     contexts, side input infos and PCollection views keyed by the generated ids
   * @throws IllegalArgumentException if the network contains an unsupported node or instruction
   *     type, a ParDo transform has an unexpected URN, or a read source or coder cannot be encoded
   * @throws IllegalStateException if a {@link RemoteGrpcPortNode} has both predecessors and
   *     successors, or neither
   */
  public Node apply(MutableNetwork<Node, Edge> input) {
    for (Node node : input.nodes()) {
      if (!(node instanceof RemoteGrpcPortNode
          || node instanceof ParallelInstructionNode
          || node instanceof InstructionOutputNode)) {
        // Report the offending node rather than dumping the whole network.
        throw new IllegalArgumentException(
            String.format("Network contains unknown type of node: %s", node));
      }
    }

    // Give every output of a non-output node a name: each DefaultEdge leaving such a node is
    // replaced by a MultiOutputInfoEdge with a freshly generated tag. Edges of any other type
    // (presumably ParDo outputs that already carry their own tags) are left untouched.
    // The edges are collected before mutating the network because modifying a Guava network while
    // iterating one of its views (successors / edgesConnecting) is undefined behavior. The
    // traversal order matches the per-edge id generation order of the direct approach.
    ImmutableList.Builder<Edge> defaultEdgesBuilder = ImmutableList.builder();
    for (Node node : input.nodes()) {
      if (node instanceof InstructionOutputNode) {
        continue;
      }
      for (Node successor : input.successors(node)) {
        for (Edge edge : input.edgesConnecting(node, successor)) {
          if (edge instanceof DefaultEdge) {
            defaultEdgesBuilder.add(edge);
          }
        }
      }
    }
    for (Edge edge : defaultEdgesBuilder.build()) {
      Node source = input.incidentNodes(edge).source();
      Node target = input.incidentNodes(edge).target();
      input.removeEdge(edge);
      input.addEdge(
          source,
          target,
          MultiOutputInfoEdge.create(new MultiOutputInfo().setTag(idGenerator.getId())));
    }

    ProcessBundleDescriptor.Builder processBundleDescriptor =
        ProcessBundleDescriptor.newBuilder()
            .setId(idGenerator.getId())
            .setStateApiServiceDescriptor(stateApiServiceDescriptor);

    SdkComponents sdkComponents = SdkComponents.create(pipeline.getComponents(), null);

    // Default to use the Java environment if pipeline doesn't have environment specified.
    if (pipeline.getComponents().getEnvironmentsMap().isEmpty()) {
      sdkComponents.registerEnvironment(Environments.JAVA_SDK_HARNESS_ENVIRONMENT);
    }

    // For intermediate PCollections we fabricate, we make a bogus WindowingStrategy.
    // TODO: create a correct windowing strategy, including coders and environment
    String fakeWindowingStrategyId = "fakeWindowingStrategy" + idGenerator.getId();
    try {
      RunnerApi.MessageWithComponents fakeWindowingStrategyProto =
          WindowingStrategyTranslation.toMessageProto(
              WindowingStrategy.globalDefault(), sdkComponents);
      processBundleDescriptor
          .putWindowingStrategies(
              fakeWindowingStrategyId, fakeWindowingStrategyProto.getWindowingStrategy())
          .putAllCoders(fakeWindowingStrategyProto.getComponents().getCodersMap())
          .putAllEnvironments(fakeWindowingStrategyProto.getComponents().getEnvironmentsMap());
    } catch (IOException exc) {
      throw new RuntimeException("Could not convert default windowing strategy to proto", exc);
    }

    Map<Node, String> nodesToPCollections = new HashMap<>();
    ImmutableMap.Builder<String, NameContext> ptransformIdToNameContexts = ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<SideInputInfo>> ptransformIdToSideInputInfos =
        ImmutableMap.builder();
    ImmutableMap.Builder<String, Iterable<PCollectionView<?>>> ptransformIdToPCollectionViews =
        ImmutableMap.builder();
    ImmutableMap.Builder<String, NameContext> pcollectionIdToNameContexts = ImmutableMap.builder();
    ImmutableMap.Builder<InstructionOutputNode, String> instructionOutputNodeToCoderIdBuilder =
        ImmutableMap.builder();

    // For each instruction output node:
    // 1. Generate new Coder and register it with SDKComponents and ProcessBundleDescriptor.
    // 2. Generate new PCollectionId and register it with ProcessBundleDescriptor.
    for (InstructionOutputNode node :
        Iterables.filter(input.nodes(), InstructionOutputNode.class)) {
      InstructionOutput instructionOutput = node.getInstructionOutput();

      String coderId = "generatedCoder" + idGenerator.getId();
      instructionOutputNodeToCoderIdBuilder.put(node, coderId);
      try {
        Coder<?> javaCoder =
            CloudObjects.coderFromCloudObject(CloudObject.fromSpec(instructionOutput.getCodec()));
        sdkComponents.registerCoder(javaCoder);
        RunnerApi.Coder coderProto = CoderTranslation.toProto(javaCoder, sdkComponents);
        processBundleDescriptor.putCoders(coderId, coderProto);
      } catch (IOException e) {
        throw new IllegalArgumentException(
            String.format(
                "Unable to encode coder %s for output %s",
                instructionOutput.getCodec(), instructionOutput),
            e);
      } catch (Exception notAJavaCoder) {
        // Coder probably wasn't a Java coder: pass its JSON encoding through as the payload of an
        // otherwise empty FunctionSpec. The output buffer is only needed on this path.
        try (ByteString.Output output = ByteString.newOutput()) {
          OBJECT_MAPPER.writeValue(output, instructionOutput.getCodec());
          processBundleDescriptor.putCoders(
              coderId,
              RunnerApi.Coder.newBuilder()
                  .setSpec(RunnerApi.FunctionSpec.newBuilder().setPayload(output.toByteString()))
                  .build());
        } catch (IOException e) {
          IllegalArgumentException failure =
              new IllegalArgumentException(
                  String.format(
                      "Unable to encode coder %s for output %s",
                      instructionOutput.getCodec(), instructionOutput),
                  e);
          // Keep the reason the Java-coder path was abandoned, for diagnostics.
          failure.addSuppressed(notAJavaCoder);
          throw failure;
        }
      }

      // Generate new PCollection ID and map it to relevant node.
      // Will later be used to fill PTransform inputs/outputs information.
      String pcollectionId = "generatedPcollection" + idGenerator.getId();
      processBundleDescriptor.putPcollections(
          pcollectionId,
          RunnerApi.PCollection.newBuilder()
              .setCoderId(coderId)
              .setWindowingStrategyId(fakeWindowingStrategyId)
              .build());
      nodesToPCollections.put(node, pcollectionId);
      pcollectionIdToNameContexts.put(
          pcollectionId,
          NameContext.create(
              null,
              instructionOutput.getOriginalName(),
              instructionOutput.getSystemName(),
              instructionOutput.getName()));
    }
    processBundleDescriptor.putAllCoders(sdkComponents.toComponents().getCodersMap());
    Map<InstructionOutputNode, String> instructionOutputNodeToCoderIdMap =
        instructionOutputNodeToCoderIdBuilder.build();

    for (ParallelInstructionNode node :
        Iterables.filter(input.nodes(), ParallelInstructionNode.class)) {
      ParallelInstruction parallelInstruction = node.getParallelInstruction();
      String ptransformId = "generatedPtransform" + idGenerator.getId();
      ptransformIdToNameContexts.put(
          ptransformId,
          NameContext.create(
              null,
              parallelInstruction.getOriginalName(),
              parallelInstruction.getSystemName(),
              parallelInstruction.getName()));

      RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();
      RunnerApi.FunctionSpec.Builder transformSpec = RunnerApi.FunctionSpec.newBuilder();

      if (parallelInstruction.getParDo() != null) {
        ParDoInstruction parDoInstruction = parallelInstruction.getParDo();
        CloudObject userFnSpec = CloudObject.fromSpec(parDoInstruction.getUserFn());
        String userFnClassName = userFnSpec.getClassName();

        if ("CombineValuesFn".equals(userFnClassName) || "KeyedCombineFn".equals(userFnClassName)) {
          transformSpec = transformCombineValuesFnToFunctionSpec(userFnSpec);
          ptransformIdToPCollectionViews.put(ptransformId, Collections.emptyList());
        } else {
          String parDoPTransformId = getString(userFnSpec, PropertyNames.SERIALIZED_FN);

          RunnerApi.PTransform parDoPTransform =
              pipeline.getComponents().getTransformsOrDefault(parDoPTransformId, null);

          // TODO: only the non-null branch should exist; for migration ease only
          if (parDoPTransform != null) {
            checkArgument(
                parDoPTransform
                    .getSpec()
                    .getUrn()
                    .equals(PTransformTranslation.PAR_DO_TRANSFORM_URN),
                "Found transform \"%s\" for ParallelDo instruction, "
                    + " but that transform had unexpected URN \"%s\" (expected \"%s\")",
                parDoPTransformId,
                parDoPTransform.getSpec().getUrn(),
                PTransformTranslation.PAR_DO_TRANSFORM_URN);

            RunnerApi.ParDoPayload parDoPayload;
            try {
              parDoPayload =
                  RunnerApi.ParDoPayload.parseFrom(parDoPTransform.getSpec().getPayload());
            } catch (InvalidProtocolBufferException exc) {
              throw new RuntimeException("ParDo did not have a ParDoPayload", exc);
            }

            // Each side input is translated twice: into a PCollectionView for the runner and into
            // the descriptor/PTransform for the SDK.
            ImmutableList.Builder<PCollectionView<?>> pcollectionViews = ImmutableList.builder();
            for (Map.Entry<String, SideInput> sideInputEntry :
                parDoPayload.getSideInputsMap().entrySet()) {
              pcollectionViews.add(
                  transformSideInputForRunner(
                      pipeline,
                      parDoPTransform,
                      sideInputEntry.getKey(),
                      sideInputEntry.getValue()));
              transformSideInputForSdk(
                  pipeline,
                  parDoPTransform,
                  sideInputEntry.getKey(),
                  processBundleDescriptor,
                  pTransform);
            }
            ptransformIdToPCollectionViews.put(ptransformId, pcollectionViews.build());

            transformSpec
                .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
                .setPayload(parDoPayload.toByteString());
          } else {
            // legacy path - bytes are the FunctionSpec's payload field, basically, and
            // SDKs expect it in the PTransform's payload field
            byte[] userFnBytes = getBytes(userFnSpec, PropertyNames.SERIALIZED_FN);
            transformSpec
                .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN)
                .setPayload(ByteString.copyFrom(userFnBytes));
          }

          // Add side input information for batch pipelines
          if (parDoInstruction.getSideInputs() != null) {
            ptransformIdToSideInputInfos.put(
                ptransformId, forSideInputInfos(parDoInstruction.getSideInputs(), true));
          }
        }
      } else if (parallelInstruction.getRead() != null) {
        ReadInstruction readInstruction = parallelInstruction.getRead();
        CloudObject sourceSpec =
            CloudObject.fromSpec(
                CloudSourceUtils.flattenBaseSpecs(readInstruction.getSource()).getSpec());
        // TODO: Need to plumb through the SDK specific function spec.
        transformSpec.setUrn(JAVA_SOURCE_URN);
        try {
          // The serialized source is stored base64-encoded in the source spec.
          byte[] serializedSource =
              Base64.getDecoder().decode(getString(sourceSpec, SERIALIZED_SOURCE));
          transformSpec.setPayload(ByteString.copyFrom(serializedSource));
        } catch (Exception e) {
          throw new IllegalArgumentException(
              String.format("Unable to process Read %s", parallelInstruction), e);
        }
      } else if (parallelInstruction.getFlatten() != null) {
        transformSpec.setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN);
      } else {
        throw new IllegalArgumentException(
            String.format("Unknown type of ParallelInstruction %s", parallelInstruction));
      }

      for (Node predecessorOutput : input.predecessors(node)) {
        pTransform.putInputs(
            "generatedInput" + idGenerator.getId(), nodesToPCollections.get(predecessorOutput));
      }

      // After the edge rewrite above, out-edges of a ParallelInstructionNode are expected to be
      // MultiOutputInfoEdges whose tags name the PTransform's outputs.
      for (Edge edge : input.outEdges(node)) {
        Node nodeOutput = input.incidentNodes(edge).target();
        MultiOutputInfoEdge outputEdge = (MultiOutputInfoEdge) edge;
        pTransform.putOutputs(
            outputEdge.getMultiOutputInfo().getTag(), nodesToPCollections.get(nodeOutput));
      }

      pTransform.setSpec(transformSpec);
      processBundleDescriptor.putTransforms(ptransformId, pTransform.build());
    }

    // Add the PTransforms representing the remote gRPC nodes: a port with only a successor reads
    // data into the bundle, a port with only a predecessor writes data out of it.
    for (RemoteGrpcPortNode node : Iterables.filter(input.nodes(), RemoteGrpcPortNode.class)) {

      RunnerApi.PTransform.Builder pTransform = RunnerApi.PTransform.newBuilder();

      Set<Node> predecessors = input.predecessors(node);
      Set<Node> successors = input.successors(node);
      Node instructionOutputNode;
      String urn;
      if (predecessors.isEmpty() && !successors.isEmpty()) {
        instructionOutputNode = Iterables.getOnlyElement(successors);
        pTransform.putOutputs(
            "generatedOutput" + idGenerator.getId(),
            nodesToPCollections.get(instructionOutputNode));
        urn = DATA_INPUT_URN;
      } else if (!predecessors.isEmpty() && successors.isEmpty()) {
        instructionOutputNode = Iterables.getOnlyElement(predecessors);
        pTransform.putInputs(
            "generatedInput" + idGenerator.getId(), nodesToPCollections.get(instructionOutputNode));
        urn = DATA_OUTPUT_URN;
      } else {
        throw new IllegalStateException(
            "Expected either one input OR one output "
                + "InstructionOutputNode for this RemoteGrpcPortNode");
      }
      // The port payload carries the coder id of the PCollection flowing through the port.
      pTransform.setSpec(
          RunnerApi.FunctionSpec.newBuilder()
              .setUrn(urn)
              .setPayload(
                  node.getRemoteGrpcPort()
                      .toBuilder()
                      .setCoderId(instructionOutputNodeToCoderIdMap.get(instructionOutputNode))
                      .build()
                      .toByteString())
              .build());
      processBundleDescriptor.putTransforms(node.getPrimitiveTransformId(), pTransform.build());
    }

    return RegisterRequestNode.create(
        RegisterRequest.newBuilder().addProcessBundleDescriptor(processBundleDescriptor).build(),
        ptransformIdToNameContexts.build(),
        ptransformIdToSideInputInfos.build(),
        ptransformIdToPCollectionViews.build(),
        pcollectionIdToNameContexts.build());
  }