def pack_per_key_combiners()

in sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py


def pack_per_key_combiners(stages, context, can_pack=lambda s: True):
  # type: (Iterable[Stage], TransformContext, Callable[[str], bool]) -> Iterator[Stage]

  """Packs sibling CombinePerKey stages into a single CombinePerKey.

  If CombinePerKey stages have a common input, one input each, and one output
  each, pack the stages into a single stage that runs all CombinePerKeys and
  outputs resulting tuples to a new PCollection. A subsequent stage unpacks
  tuples from this PCollection and sends them to the original output
  PCollections.
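
  For example, two sibling CombinePerKey transforms over the same input become
  a single CombinePerKey whose CombineFn is a SingleInputTupleCombineFn of the
  original CombineFns, followed by a ParDo that unpacks each packed tuple and
  routes its components back to the original output PCollections.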
  """
  class _UnpackFn(core.DoFn):
    """A DoFn that unpacks a packed to multiple tagged outputs.

    Example:
      tags = (T1, T2, ...)
      input = (K, (V1, V2, ...))
      output = TaggedOutput(T1, (K, V1)), TaggedOutput(T2, (K, V2)), ...
    """
    def __init__(self, tags):
      self._tags = tags

    def process(self, element):
      key, values = element
      return [
          core.pvalue.TaggedOutput(tag, (key, value))
          for tag, value in zip(self._tags, values)
      ]

  def _get_fallback_coder_id():
    return context.add_or_get_coder_id(
        # passing None works here because there are no component coders
        coders.registry.get_coder(object).to_runner_api(None))  # type: ignore[arg-type]

  def _get_component_coder_id_from_kv_coder(coder, index):
    assert index < 2
    if (coder.spec.urn == common_urns.coders.KV.urn and
        len(coder.component_coder_ids) == 2):
      return coder.component_coder_ids[index]
    return _get_fallback_coder_id()

  def _get_key_coder_id_from_kv_coder(coder):
    return _get_component_coder_id_from_kv_coder(coder, 0)

  def _get_value_coder_id_from_kv_coder(coder):
    return _get_component_coder_id_from_kv_coder(coder, 1)

  def _try_fuse_stages(a, b):
    if a.can_fuse(b, context):
      return a.fuse(b, context)
    else:
      raise ValueError

  # Partition stages by whether they are eligible for CombinePerKey packing
  # and group eligible CombinePerKey stages by common input PCollection and
  # environment.
  def get_stage_key(stage):
    if (len(stage.transforms) == 1 and can_pack(stage.name) and
        stage.environment is not None and python_urns.PACKED_COMBINE_FN in
        context.components.environments[stage.environment].capabilities):
      transform = only_transform(stage.transforms)
      if (transform.spec.urn == common_urns.composites.COMBINE_PER_KEY.urn and
          len(transform.inputs) == 1 and len(transform.outputs) == 1):
        combine_payload = proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.CombinePayload)
        if combine_payload.combine_fn.urn == python_urns.PICKLED_COMBINE_FN:
          return (only_element(transform.inputs.values()), stage.environment)
    return None

  grouped_eligible_stages, ineligible_stages = _group_stages_by_key(
      stages, get_stage_key)
  for stage in ineligible_stages:
    yield stage

  for stage_key, packable_stages in grouped_eligible_stages.items():
    input_pcoll_id, _ = stage_key
    try:
      if len(packable_stages) < 2:
        raise ValueError('Only one stage in this group: Skipping stage packing')
      # The fused stage is used as a template and is not yielded.
      fused_stage = functools.reduce(_try_fuse_stages, packable_stages)
    except ValueError:
      # Skip packing stages in this group.
      # Yield the stages unmodified, and then continue to the next group.
      for stage in packable_stages:
        yield stage
      continue

    transforms = [only_transform(stage.transforms) for stage in packable_stages]
    combine_payloads = [
        proto_utils.parse_Bytes(
            transform.spec.payload, beam_runner_api_pb2.CombinePayload)
        for transform in transforms
    ]
    output_pcoll_ids = [
        only_element(transform.outputs.values()) for transform in transforms
    ]

    # Build accumulator coder for (acc1, acc2, ...)
    accumulator_coder_ids = [
        combine_payload.accumulator_coder_id
        for combine_payload in combine_payloads
    ]
    tuple_accumulator_coder_id = context.add_or_get_coder_id(
        beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
            component_coder_ids=accumulator_coder_ids))

    # Build packed output coder for (key, (out1, out2, ...))
    input_kv_coder_id = context.components.pcollections[input_pcoll_id].coder_id
    key_coder_id = _get_key_coder_id_from_kv_coder(
        context.components.coders[input_kv_coder_id])
    output_kv_coder_ids = [
        context.components.pcollections[output_pcoll_id].coder_id
        for output_pcoll_id in output_pcoll_ids
    ]
    output_value_coder_ids = [
        _get_value_coder_id_from_kv_coder(
            context.components.coders[output_kv_coder_id])
        for output_kv_coder_id in output_kv_coder_ids
    ]
    pack_output_value_coder = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(urn=python_urns.TUPLE_CODER),
        component_coder_ids=output_value_coder_ids)
    pack_output_value_coder_id = context.add_or_get_coder_id(
        pack_output_value_coder)
    pack_output_kv_coder = beam_runner_api_pb2.Coder(
        spec=beam_runner_api_pb2.FunctionSpec(urn=common_urns.coders.KV.urn),
        component_coder_ids=[key_coder_id, pack_output_value_coder_id])
    pack_output_kv_coder_id = context.add_or_get_coder_id(pack_output_kv_coder)

    pack_stage_name = _make_pack_name([stage.name for stage in packable_stages])
    pack_transform_name = _make_pack_name([
        only_transform(stage.transforms).unique_name
        for stage in packable_stages
    ])
    pack_pcoll_id = unique_name(context.components.pcollections, 'pcollection')
    input_pcoll = context.components.pcollections[input_pcoll_id]
    context.components.pcollections[pack_pcoll_id].CopyFrom(
        beam_runner_api_pb2.PCollection(
            unique_name=pack_transform_name + '/Pack.out',
            coder_id=pack_output_kv_coder_id,
            windowing_strategy_id=input_pcoll.windowing_strategy_id,
            is_bounded=input_pcoll.is_bounded))

    # Set up Pack stage.
    # TODO(BEAM-7746): classes that inherit from RunnerApiFn are expected to
    #  accept a PipelineContext for from_runner_api/to_runner_api.  Determine
    #  how to accommodate this.
    pack_combine_fn = combiners.SingleInputTupleCombineFn(
        *[
            core.CombineFn.from_runner_api(combine_payload.combine_fn, context)  # type: ignore[arg-type]
            for combine_payload in combine_payloads
        ]).to_runner_api(context)  # type: ignore[arg-type]
    pack_transform = beam_runner_api_pb2.PTransform(
        unique_name=pack_transform_name + '/Pack',
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.composites.COMBINE_PER_KEY.urn,
            payload=beam_runner_api_pb2.CombinePayload(
                combine_fn=pack_combine_fn,
                accumulator_coder_id=tuple_accumulator_coder_id
            ).SerializeToString()),
        inputs={'in': input_pcoll_id},
        # 'None' single output key follows convention for CombinePerKey.
        outputs={'None': pack_pcoll_id},
        environment_id=fused_stage.environment)
    pack_stage = Stage(
        pack_stage_name + '/Pack', [pack_transform],
        downstream_side_inputs=fused_stage.downstream_side_inputs,
        must_follow=fused_stage.must_follow,
        parent=fused_stage.parent,
        environment=fused_stage.environment)
    yield pack_stage

    # Set up Unpack stage.
    tags = [str(i) for i in range(len(output_pcoll_ids))]
    pickled_do_fn_data = pickler.dumps((_UnpackFn(tags), (), {}, [], None))
    unpack_transform = beam_runner_api_pb2.PTransform(
        unique_name=pack_transform_name + '/Unpack',
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.primitives.PAR_DO.urn,
            payload=beam_runner_api_pb2.ParDoPayload(
                do_fn=beam_runner_api_pb2.FunctionSpec(
                    urn=python_urns.PICKLED_DOFN_INFO,
                    payload=pickled_do_fn_data)).SerializeToString()),
        inputs={'in': pack_pcoll_id},
        outputs=dict(zip(tags, output_pcoll_ids)),
        environment_id=fused_stage.environment)
    unpack_stage = Stage(
        pack_stage_name + '/Unpack', [unpack_transform],
        downstream_side_inputs=fused_stage.downstream_side_inputs,
        must_follow=fused_stage.must_follow,
        parent=fused_stage.parent,
        environment=fused_stage.environment)
    yield unpack_stage
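
The sketch below (standalone, not part of translations.py) illustrates the
pipeline shape this pass produces: two sibling combiners over the same keyed
input are expressed as a single CombinePerKey built from
SingleInputTupleCombineFn, and a ParDo then unpacks each
(key, (value1, value2)) tuple into per-combiner outputs, mirroring _UnpackFn.
The transform labels and output tags ('Pack', 'Unpack', 'mean', 'count') are
illustrative choices for this example, not names produced by the pass.

import apache_beam as beam
from apache_beam.transforms import combiners


def unpack(element, tags):
  # Mirror _UnpackFn: split (key, (v1, v2, ...)) into tagged (key, v) pairs.
  key, values = element
  for tag, value in zip(tags, values):
    yield beam.pvalue.TaggedOutput(tag, (key, value))


with beam.Pipeline() as pipeline:
  packed = (
      pipeline
      | beam.Create([('a', 1), ('a', 3), ('b', 2)])
      # A single CombinePerKey that runs both combiners over the same input.
      | 'Pack' >> beam.CombinePerKey(
          combiners.SingleInputTupleCombineFn(
              combiners.MeanCombineFn(), combiners.CountCombineFn())))
  unpacked = packed | 'Unpack' >> beam.FlatMap(
      unpack, tags=('mean', 'count')).with_outputs('mean', 'count')
  unpacked.mean | 'PrintMean' >> beam.Map(print)  # ('a', 2.0), ('b', 2.0)
  unpacked.count | 'PrintCount' >> beam.Map(print)  # ('a', 2), ('b', 1)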