in tfx/types/artifact_utils.py [0:0]
def get_artifact_type_class(
    artifact_type: metadata_store_pb2.ArtifactType) -> Type[Artifact]:
  """Get the artifact type class corresponding to an MLMD type proto.

  Lookup order:
    1. A natively-defined `Artifact` subclass whose MLMD type has the same
       `name` and `properties` as `artifact_type`.
    2. A previously auto-generated subclass (one carrying a truthy
       `_AUTOGENERATED` attribute) that matches in the same way.
    3. A class generated on the fly. A warning is logged in this case, and
       the new class is tagged with `_AUTOGENERATED = True`.

  Args:
    artifact_type: The MLMD artifact type proto to resolve.

  Returns:
    The matching `Artifact` subclass if one is found. Otherwise a newly
    generated class: built with `_ValueArtifactType` when `artifact_type.name`
    starts with the `TYPE_NAME` of a concrete `ValueArtifact` subclass, and
    with `_ArtifactType` in every other case.
  """
  # Make sure this module path containing the standard Artifact subclass
  # definitions is imported. Modules containing custom artifact subclasses that
  # need to be deserialized should be imported by the entrypoint of the
  # application or container.
  from tfx.types import standard_artifacts  # pylint: disable=g-import-not-at-top,import-outside-toplevel,unused-import,unused-variable

  # Enumerate the Artifact type ontology, separated into auto-generated and
  # natively-defined classes. Classes without a TYPE_NAME are abstract and are
  # dropped here, so every list below holds concrete classes only.
  native_artifact_classes = []
  generated_artifact_classes = []
  value_artifact_classes = []
  for cls in _get_subclasses(Artifact):
    if not cls.TYPE_NAME:
      # Skip abstract classes.
      continue
    if getattr(cls, '_AUTOGENERATED', False):
      generated_artifact_classes.append(cls)
    else:
      native_artifact_classes.append(cls)
    if issubclass(cls, ValueArtifact):
      value_artifact_classes.append(cls)

  # Try to find an existing class for the artifact type, if it exists. Prefer
  # to use a native artifact class (natives are chained first).
  for cls in itertools.chain(native_artifact_classes,
                             generated_artifact_classes):
    candidate_type = cls._get_artifact_type()  # pylint: disable=protected-access
    # We need to compare `.name` and `.properties` (and not the entire proto
    # directly), because the proto `.id` field will be populated when the type
    # is read from MLMD.
    if (artifact_type.name == candidate_type.name and
        artifact_type.properties == candidate_type.properties):
      return cls

  # Generate a class for the artifact type on the fly.
  logging.warning(
      'Could not find matching artifact class for type %r (proto: %r); '
      'generating an ephemeral artifact class on-the-fly. If this is not '
      'intended, please make sure that the artifact class for this type can '
      'be imported within your container or environment where a component '
      'is executed to consume this type.', artifact_type.name,
      str(artifact_type))

  # Pick the first concrete ValueArtifact subclass whose TYPE_NAME is a prefix
  # of the requested type name. No second TYPE_NAME check is needed because
  # abstract classes were already filtered out above.
  # NOTE(review): if several TYPE_NAMEs are prefixes of the same name, the
  # winner depends on the order returned by `_get_subclasses` -- confirm that
  # this is acceptable.
  value_base = next(
      (cls for cls in value_artifact_classes
       if artifact_type.name.startswith(cls.TYPE_NAME)), None)
  if value_base is not None:
    new_artifact_class = _ValueArtifactType(
        mlmd_artifact_type=artifact_type, base=value_base)
  else:
    new_artifact_class = _ArtifactType(mlmd_artifact_type=artifact_type)
  # Tag the ephemeral class so later lookups treat it as auto-generated.
  setattr(new_artifact_class, '_AUTOGENERATED', True)
  return new_artifact_class