private void translateExecutableStage()

in runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java [612:810]

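Translates a fused ExecutableStage into a single Flink streaming operator. The method resolves the stage's input and output PCollections, wires side inputs in as a broadcast stream, keys the input stream for stateful and splittable (SDF) stages, and routes additional outputs through Flink side outputs.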

  private <InputT, OutputT> void translateExecutableStage(
      String id, RunnerApi.Pipeline pipeline, StreamingTranslationContext context) {
    // TODO: Fail on splittable DoFns.
    // TODO: Special-case single outputs to avoid multiplexing PCollections.
    RunnerApi.Components components = pipeline.getComponents();
    RunnerApi.PTransform transform = components.getTransformsOrThrow(id);
    Map<String, String> outputs = transform.getOutputsMap();

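    // The transform's spec payload carries the fused ExecutableStage definition:
    // its input, outputs, side inputs, user state, and timers.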
    final RunnerApi.ExecutableStagePayload stagePayload;
    try {
      stagePayload = RunnerApi.ExecutableStagePayload.parseFrom(transform.getSpec().getPayload());
    } catch (IOException e) {
      throw new RuntimeException(e);
    }

    String inputPCollectionId = stagePayload.getInput();
    final TransformedSideInputs transformedSideInputs;

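    // If the stage consumes side inputs, they are unioned into a single stream
    // that is broadcast to the operator further below.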
    if (stagePayload.getSideInputsCount() > 0) {
      transformedSideInputs = transformSideInputs(stagePayload, components, context);
    } else {
      transformedSideInputs = new TransformedSideInputs(Collections.emptyMap(), null);
    }

    Map<TupleTag<?>, OutputTag<WindowedValue<?>>> tagsToOutputTags = Maps.newLinkedHashMap();
    Map<TupleTag<?>, Coder<WindowedValue<?>>> tagsToCoders = Maps.newLinkedHashMap();
    // TODO: Does it matter which output we designate as "main"?
    final TupleTag<OutputT> mainOutputTag =
        outputs.isEmpty() ? null : new TupleTag(outputs.keySet().iterator().next());

    // Associate output tags with ids; the output manager uses these integer ids to serialize state.
    BiMap<String, Integer> outputIndexMap = createOutputMap(outputs.keySet());
    Map<String, Coder<WindowedValue<?>>> outputCoders = Maps.newHashMap();
    Map<TupleTag<?>, Integer> tagsToIds = Maps.newHashMap();
    Map<String, TupleTag<?>> collectionIdToTupleTag = Maps.newHashMap();
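    // collectionIdToTupleTag lets the operator map stage outputs, which are
    // identified by PCollection id, back to their Flink output tags.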
    // Order output names for a deterministic tag-to-id mapping.
    for (String localOutputName : new TreeMap<>(outputIndexMap).keySet()) {
      String collectionId = outputs.get(localOutputName);
      Coder<WindowedValue<?>> windowCoder = (Coder) instantiateCoder(collectionId, components);
      outputCoders.put(localOutputName, windowCoder);
      TupleTag<?> tupleTag = new TupleTag<>(localOutputName);
      CoderTypeInformation<WindowedValue<?>> typeInformation =
          new CoderTypeInformation(windowCoder, context.getPipelineOptions());
      tagsToOutputTags.put(tupleTag, new OutputTag<>(localOutputName, typeInformation));
      tagsToCoders.put(tupleTag, windowCoder);
      tagsToIds.put(tupleTag, outputIndexMap.get(localOutputName));
      collectionIdToTupleTag.put(collectionId, tupleTag);
    }

    final SingleOutputStreamOperator<WindowedValue<OutputT>> outputStream;
    DataStream<WindowedValue<InputT>> inputDataStream =
        context.getDataStreamOrThrow(inputPCollectionId);

    CoderTypeInformation<WindowedValue<OutputT>> outputTypeInformation =
        (!outputs.isEmpty())
            ? new CoderTypeInformation(
                outputCoders.get(mainOutputTag.getId()), context.getPipelineOptions())
            : null;

    ArrayList<TupleTag<?>> additionalOutputTags = Lists.newArrayList();
    for (TupleTag<?> tupleTag : tagsToCoders.keySet()) {
      if (!mainOutputTag.getId().equals(tupleTag.getId())) {
        additionalOutputTags.add(tupleTag);
      }
    }

    final Coder<WindowedValue<InputT>> windowedInputCoder =
        instantiateCoder(inputPCollectionId, components);

    final boolean stateful =
        stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
    final boolean hasSdfProcessFn =
        stagePayload.getComponents().getTransformsMap().values().stream()
            .anyMatch(
                pTransform ->
                    pTransform
                        .getSpec()
                        .getUrn()
                        .equals(
                            PTransformTranslation
                                .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN));
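    // Stateful and SDF stages must run on a keyed stream so the operator can use
    // Flink keyed state; derive the key coder and key selector from the KV input.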
    Coder keyCoder = null;
    KeySelector<WindowedValue<InputT>, ?> keySelector = null;
    if (stateful || hasSdfProcessFn) {
      // Stateful/SDF stages are only allowed for KV inputs.
      Coder valueCoder =
          ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
      if (!(valueCoder instanceof KvCoder)) {
        throw new IllegalStateException(
            String.format(
                Locale.ENGLISH,
                "The element coder for stateful DoFn '%s' must be KvCoder but is: %s",
                inputPCollectionId,
                valueCoder.getClass().getSimpleName()));
      }
      if (stateful) {
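        // Key by the serialized key bytes: the selector encodes the KV key with
        // its coder into a ByteBuffer so Flink can hash and compare it consistently.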
        keyCoder = ((KvCoder) valueCoder).getKeyCoder();
        keySelector =
            new KvToByteBufferKeySelector(
                keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
      } else {
        // For an SDF, we know that the input element should be
        // KV<KV<element, KV<restriction, watermarkState>>, size>. We are going to use the element
        // as the key.
        if (!(((KvCoder) valueCoder).getKeyCoder() instanceof KvCoder)) {
          throw new IllegalStateException(
              String.format(
                  Locale.ENGLISH,
                  "The element coder for splittable DoFn '%s' must be KVCoder(KvCoder, DoubleCoder) but is: %s",
                  inputPCollectionId,
                  valueCoder.getClass().getSimpleName()));
        }
        keyCoder = ((KvCoder) ((KvCoder) valueCoder).getKeyCoder()).getKeyCoder();
        keySelector =
            new SdfByteBufferKeySelector(
                keyCoder, new SerializablePipelineOptions(context.getPipelineOptions()));
      }
      inputDataStream = inputDataStream.keyBy(keySelector);
    }

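    // The output manager multiplexes all of the stage's outputs: the main output
    // flows through the operator's primary stream, the others through Flink side
    // outputs identified by the tags and ids assigned above.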
    DoFnOperator.MultiOutputOutputManagerFactory<OutputT> outputManagerFactory =
        new DoFnOperator.MultiOutputOutputManagerFactory<>(
            mainOutputTag,
            tagsToOutputTags,
            tagsToCoders,
            tagsToIds,
            new SerializablePipelineOptions(context.getPipelineOptions()));

    DoFnOperator<InputT, OutputT> doFnOperator =
        new ExecutableStageDoFnOperator<>(
            transform.getUniqueName(),
            windowedInputCoder,
            Collections.emptyMap(),
            mainOutputTag,
            additionalOutputTags,
            outputManagerFactory,
            transformedSideInputs.unionTagToView,
            new ArrayList<>(transformedSideInputs.unionTagToView.values()),
            getSideInputIdToPCollectionViewMap(stagePayload, components),
            context.getPipelineOptions(),
            stagePayload,
            context.getJobInfo(),
            FlinkExecutableStageContextFactory.getInstance(),
            collectionIdToTupleTag,
            getWindowingStrategy(inputPCollectionId, components),
            keyCoder,
            keySelector);

    final String operatorName = generateNameFromStagePayload(stagePayload);

    if (transformedSideInputs.unionTagToView.isEmpty()) {
      outputStream = inputDataStream.transform(operatorName, outputTypeInformation, doFnOperator);
    } else {
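      // Broadcast the unioned side-input stream to every parallel instance of the stage.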
      DataStream<RawUnionValue> sideInputStream =
          transformedSideInputs.unionedSideInputs.broadcast();
      if (stateful || hasSdfProcessFn) {
        // We have to manually construct the two-input transform because Flink
        // normally does not allow a two-input operator with only one keyed input.
        // Since Flink 1.5.0 the Broadcast State Pattern provides a more elegant
        // way to process a keyed main input together with broadcast state, but
        // it is not feasible here because it would break the DoFnOperator
        // abstraction.
        TwoInputTransformation<WindowedValue<KV<?, InputT>>, RawUnionValue, WindowedValue<OutputT>>
            rawFlinkTransform =
                new TwoInputTransformation(
                    inputDataStream.getTransformation(),
                    sideInputStream.getTransformation(),
                    transform.getUniqueName(),
                    doFnOperator,
                    outputTypeInformation,
                    inputDataStream.getParallelism());

        rawFlinkTransform.setStateKeyType(((KeyedStream) inputDataStream).getKeyType());
        rawFlinkTransform.setStateKeySelectors(
            ((KeyedStream) inputDataStream).getKeySelector(), null);
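        // Only the main (first) input is keyed; the broadcast side-input channel
        // carries no key selector.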

        outputStream =
            new SingleOutputStreamOperator(
                inputDataStream.getExecutionEnvironment(),
                rawFlinkTransform) {}; // we have to cheat around the ctor being protected
      } else {
        outputStream =
            inputDataStream
                .connect(sideInputStream)
                .transform(operatorName, outputTypeInformation, doFnOperator);
      }
    }
    // Assign an id that is unique within the job but stable across submissions,
    // so operator state can be re-mapped on restore.
    outputStream.uid(transform.getUniqueName());

    if (mainOutputTag != null) {
      context.addDataStream(outputs.get(mainOutputTag.getId()), outputStream);
    }

    for (TupleTag<?> tupleTag : additionalOutputTags) {
      context.addDataStream(
          outputs.get(tupleTag.getId()),
          outputStream.getSideOutput(tagsToOutputTags.get(tupleTag)));
    }
  }