in tfx/dsl/input_resolution/resolver_function.py [0:0]
def __call__(self, *args, **kwargs):
  """Invokes the resolver function.

  This traces the @resolver_function with the given arguments. BaseChannel
  arguments are converted to InputNodes for tracing. The return value depends
  on the output data type of the traced function:

  * ARTIFACT_LIST: returns a ResolvedChannel (a BaseChannel) that can be used
    as a component input.
  * ARTIFACT_MULTIMAP: returns a Mapping[str, BaseChannel] with one
    ResolvedChannel per key of the output type.
  * ARTIFACT_MULTIMAP_LIST: returns a Loopable, an intermediate object that
    can be unwrapped to a Mapping[str, BaseChannel] with the ForEach context
    manager.

  Args:
    *args: Arguments to the wrapped function.
    **kwargs: Keyword arguments to the wrapped function.

  Raises:
    RuntimeError: If the output type is unset and cannot be inferred, if it is
      incompatible with the traced output data type, or if the traced output
      data type is none of the supported types listed above.
    TypeError: If a loopable_transform is set but the traced output data type
      is not ARTIFACT_MULTIMAP_LIST.

  Returns:
    The resolver function result, shaped as described above.
  """
  # A preset self._output_type takes precedence. Otherwise (including when it
  # is falsy) the type is inferred from the raw, unconverted call arguments.
  output_type = self._output_type or (
      self._output_type_inferrer(*args, **kwargs)
  )
  if output_type is None:
    raise RuntimeError(
        'Unable to infer output type. Please use '
        'resolver_function.with_output_type()'
    )

  # Convert arguments (e.g. BaseChannels) to nodes so they can be traced.
  args = [self._try_convert_to_node(v) for v in args]
  kwargs = {k: self._try_convert_to_node(v) for k, v in kwargs.items()}
  out = self.trace(*args, **kwargs)
  # Reuse a preset invocation if there is one; otherwise record this call with
  # its converted arguments.
  invocation = self._invocation or resolved_channel.Invocation(
      function=self, args=args, kwargs=kwargs
  )
  data_type = out.output_data_type

  if data_type == resolver_op.DataType.ARTIFACT_LIST:
    if self._loopable_transform is not None:
      raise TypeError(
          'loopable_transform is not applicable for ARTIFACT_LIST output'
      )
    if not typing_utils.is_compatible(output_type, _ArtifactType):
      raise RuntimeError(
          f'Invalid output_type {output_type}. Expected {_ArtifactType}'
      )
    output_type = cast(_ArtifactType, output_type)
    return resolved_channel.ResolvedChannel(
        artifact_type=output_type, output_node=out, invocation=invocation
    )

  if data_type == resolver_op.DataType.ARTIFACT_MULTIMAP:
    if self._loopable_transform is not None:
      raise TypeError(
          'loopable_transform is not applicable for ARTIFACT_MULTIMAP output'
      )
    if not typing_utils.is_compatible(output_type, _ArtifactTypeMap):
      raise RuntimeError(
          f'Invalid output_type {output_type}. Expected {_ArtifactTypeMap}'
      )
    output_type = cast(_ArtifactTypeMap, output_type)
    # Every key shares the same output node; output_key selects the entry.
    return {
        key: resolved_channel.ResolvedChannel(
            artifact_type=artifact_type,
            output_node=out,
            output_key=key,
            invocation=invocation,
        )
        for key, artifact_type in output_type.items()
    }

  if data_type == resolver_op.DataType.ARTIFACT_MULTIMAP_LIST:
    if not typing_utils.is_compatible(output_type, _ArtifactTypeMap):
      raise RuntimeError(
          f'Invalid output_type {output_type}. Expected {_ArtifactTypeMap}'
      )
    output_type = cast(_ArtifactTypeMap, output_type)

    def loop_var_factory(for_each_context: for_each_internal.ForEachContext):
      """Builds the per-key channels bound to the given ForEach context."""
      result = {
          key: resolved_channel.ResolvedChannel(
              artifact_type=artifact_type,
              output_node=out,
              output_key=key,
              invocation=invocation,
              for_each_context=for_each_context,
          )
          for key, artifact_type in output_type.items()
      }
      # The optional transform post-processes the loop variable mapping.
      if self._loopable_transform is not None:
        result = self._loopable_transform(result)
      return result

    return for_each_internal.Loopable(loop_var_factory)

  # Previously this fell through and implicitly returned None, which only
  # failed later and far from the cause. Fail loudly instead.
  raise RuntimeError(
      f'Unsupported output data type {data_type} of the resolver function.'
  )