def __call__()

in tfx/dsl/input_resolution/resolver_function.py [0:0]


  def __call__(self, *args, **kwargs):
    """Invokes the resolver function.

    Traces the @resolver_function with the given arguments. Each BaseChannel
    argument is converted to an InputNode for tracing. The return value depends
    on the data type of the traced output:

    * ARTIFACT_LIST: returns a BaseChannel instance that can be used as a
      component input.
    * ARTIFACT_MULTIMAP: returns a Mapping[str, BaseChannel].
    * ARTIFACT_MULTIMAP_LIST: returns an intermediate Loopable object that can
      be unwrapped to a Mapping[str, BaseChannel] with the ForEach context
      manager.

    Args:
      *args: Arguments to the wrapped function.
      **kwargs: Keyword arguments to the wrapped function.

    Raises:
      RuntimeError: If output_type is unset and cannot be inferred, if it is
        incompatible with the traced output data type, or if the traced output
        data type is not one of the supported types.
      TypeError: If a loopable_transform is set but the traced output is not
        ARTIFACT_MULTIMAP_LIST.

    Returns:
      The resolver function result: a BaseChannel, a mapping of BaseChannels,
      or a Loopable, depending on the traced output data type.
    """
    # An explicitly configured output type takes precedence over inference.
    # Inference runs on the caller's raw (unconverted) arguments.
    output_type = self._output_type or self._output_type_inferrer(
        *args, **kwargs
    )
    if output_type is None:
      raise RuntimeError(
          'Unable to infer output type. Please use '
          'resolver_function.with_output_type()'
      )

    args = [self._try_convert_to_node(v) for v in args]
    kwargs = {k: self._try_convert_to_node(v) for k, v in kwargs.items()}
    out = self.trace(*args, **kwargs)

    invocation = self._invocation or resolved_channel.Invocation(
        function=self, args=args, kwargs=kwargs
    )
    data_type = out.output_data_type

    if data_type == resolver_op.DataType.ARTIFACT_LIST:
      if self._loopable_transform is not None:
        raise TypeError(
            'loopable_transform is not applicable for ARTIFACT_LIST output'
        )
      if not typing_utils.is_compatible(output_type, _ArtifactType):
        raise RuntimeError(
            f'Invalid output_type {output_type}. Expected {_ArtifactType}'
        )
      output_type = cast(_ArtifactType, output_type)
      return resolved_channel.ResolvedChannel(
          artifact_type=output_type, output_node=out, invocation=invocation
      )

    if data_type == resolver_op.DataType.ARTIFACT_MULTIMAP:
      if self._loopable_transform is not None:
        raise TypeError(
            'loopable_transform is not applicable for ARTIFACT_MULTIMAP output'
        )
      if not typing_utils.is_compatible(output_type, _ArtifactTypeMap):
        raise RuntimeError(
            f'Invalid output_type {output_type}. Expected {_ArtifactTypeMap}'
        )
      output_type = cast(_ArtifactTypeMap, output_type)
      # One channel per output key, all backed by the same traced node.
      return {
          key: resolved_channel.ResolvedChannel(
              artifact_type=artifact_type,
              output_node=out,
              output_key=key,
              invocation=invocation,
          )
          for key, artifact_type in output_type.items()
      }

    if data_type == resolver_op.DataType.ARTIFACT_MULTIMAP_LIST:
      if not typing_utils.is_compatible(output_type, _ArtifactTypeMap):
        raise RuntimeError(
            f'Invalid output_type {output_type}. Expected {_ArtifactTypeMap}'
        )
      output_type = cast(_ArtifactTypeMap, output_type)

      # Channels are created lazily so that each one is bound to the
      # ForEachContext supplied when the Loopable is unwrapped.
      def loop_var_factory(for_each_context: for_each_internal.ForEachContext):
        result = {
            key: resolved_channel.ResolvedChannel(
                artifact_type=artifact_type,
                output_node=out,
                output_key=key,
                invocation=invocation,
                for_each_context=for_each_context,
            )
            for key, artifact_type in output_type.items()
        }
        if self._loopable_transform is not None:
          result = self._loopable_transform(result)
        return result

      return for_each_internal.Loopable(loop_var_factory)

    # Previously this fell through and silently returned None.
    raise RuntimeError(
        f'Unsupported output data type {data_type} returned by the traced '
        'resolver function.'
    )