in tfx/dsl/components/common/resolver.py [0:0]
def __init__(self,
             strategy_class: Optional[Type[ResolverStrategy]] = None,
             config: Optional[Dict[str, json_utils.JsonableType]] = None,
             **channels: types.BaseChannel):
  """Init function for Resolver.

  Args:
    strategy_class: Optional `ResolverStrategy` subclass (the class itself,
      not an instance) which contains the artifact resolution logic.
    config: Optional dict of key to Jsonable type for constructing
      resolver_strategy. May only be given together with `strategy_class`.
    **channels: Input channels to the Resolver node as keyword arguments.

  Raises:
    TypeError: If `strategy_class` is given but is not a subclass of
      `ResolverStrategy` (including when it is not a class at all).
    ValueError: If `config` is given without `strategy_class`, or if any
      keyword argument in `channels` is not a `BaseChannel`.
  """
  # Check that strategy_class is a class before calling issubclass():
  # issubclass() itself raises an unhelpful TypeError for non-class
  # arguments (e.g. a strategy instance), hiding the message below.
  if strategy_class is not None and not (
      isinstance(strategy_class, type) and
      issubclass(strategy_class, ResolverStrategy)):
    raise TypeError('strategy_class should be ResolverStrategy, but got '
                    f'{strategy_class} instead.')
  if strategy_class is None and config is not None:
    raise ValueError('Cannot use config parameter without strategy_class.')
  for input_key, channel in channels.items():
    if not isinstance(channel, channel_types.BaseChannel):
      raise ValueError(f'Resolver got non-BaseChannel argument {input_key}.')
  self._strategy_class = strategy_class
  self._config = config or {}
  # The inputs as observed from the DSL, as if the Resolver node took them
  # directly as its inputs.
  # TODO(b/246907396): Remove raw_inputs usage.
  self._raw_inputs = dict(channels)
  if strategy_class is not None:
    # A single strategy op node consumes all input channels as one dict.
    # Each input key is then exposed as a ResolvedChannel that reads the
    # corresponding key from that op's output.
    output_node = resolver_op.OpNode(
        op_type=strategy_class,
        output_data_type=resolver_op.DataType.ARTIFACT_MULTIMAP,
        args=[
            resolver_op.DictNode({
                input_key: resolver_op.InputNode(channel)
                for input_key, channel in channels.items()
            })
        ],
        kwargs=self._config)
    self._input_dict = {
        k: resolved_channel.ResolvedChannel(c.type, output_node, k)
        for k, c in channels.items()
    }
  else:
    self._input_dict = channels
  # Each output channel mirrors an input key and keeps its artifact type.
  self._output_dict = {
      input_key: types.OutputChannel(channel.type, self, input_key)
      for input_key, channel in channels.items()
  }
  super().__init__(driver_class=_ResolverDriver)