def get_artifact_type_class()

in tfx/types/artifact_utils.py [0:0]


def get_artifact_type_class(
    artifact_type: metadata_store_pb2.ArtifactType) -> Type[Artifact]:
  """Get the artifact type class corresponding to an MLMD type proto.

  Resolution order:
    1. Return an already-defined Artifact subclass whose MLMD type has the
       same `name` and `properties` as `artifact_type`. Natively-defined
       classes are preferred over auto-generated ones.
    2. Otherwise generate an ephemeral class (and log a warning). If the type
       name starts with the TYPE_NAME of a known ValueArtifact subclass, the
       generated class uses that subclass as its base. When several TYPE_NAMEs
       match, the longest one (the most specific base) is chosen. If none
       matches, a plain artifact class is generated.

  Generated classes are tagged with `_AUTOGENERATED = True`.

  Args:
    artifact_type: The MLMD ArtifactType proto to resolve.

  Returns:
    An Artifact subclass representing `artifact_type`.
  """

  # Make sure this module path containing the standard Artifact subclass
  # definitions is imported. Modules containing custom artifact subclasses that
  # need to be deserialized should be imported by the entrypoint of the
  # application or container.
  from tfx.types import standard_artifacts  # pylint: disable=g-import-not-at-top,import-outside-toplevel,unused-import,unused-variable

  # Enumerate the Artifact type ontology, separated into auto-generated and
  # natively-defined classes. Value artifact classes are tracked separately as
  # candidate bases for on-the-fly generation.
  native_artifact_classes = []
  generated_artifact_classes = []
  value_artifact_classes = []
  for cls in _get_subclasses(Artifact):
    if not cls.TYPE_NAME:
      # Skip abstract classes.
      continue
    if getattr(cls, '_AUTOGENERATED', False):
      generated_artifact_classes.append(cls)
    else:
      native_artifact_classes.append(cls)
    if issubclass(cls, ValueArtifact):
      value_artifact_classes.append(cls)

  # Try to find an existing class for the artifact type, if it exists. Prefer
  # to use a native artifact class.
  for cls in itertools.chain(native_artifact_classes,
                             generated_artifact_classes):
    candidate_type = cls._get_artifact_type()  # pylint: disable=protected-access
    # We need to compare `.name` and `.properties` (and not the entire proto
    # directly), because the proto `.id` field will be populated when the type
    # is read from MLMD.
    if (artifact_type.name == candidate_type.name and
        artifact_type.properties == candidate_type.properties):
      return cls

  # Generate a class for the artifact type on the fly.
  logging.warning(
      'Could not find matching artifact class for type %r (proto: %r); '
      'generating an ephemeral artifact class on-the-fly. If this is not '
      'intended, please make sure that the artifact class for this type can '
      'be imported within your container or environment where a component '
      'is executed to consume this type.', artifact_type.name,
      str(artifact_type))

  # Every class in `value_artifact_classes` has a non-empty TYPE_NAME (abstract
  # classes were skipped above), so no extra emptiness check is needed here.
  matching_value_classes = [
      cls for cls in value_artifact_classes
      if artifact_type.name.startswith(cls.TYPE_NAME)
  ]
  if matching_value_classes:
    # Choose the most specific (longest) matching TYPE_NAME so the result does
    # not depend on subclass enumeration order. `max` returns the first of any
    # equal-length ties, which keeps the original first-match order for them.
    base_class = max(matching_value_classes, key=lambda c: len(c.TYPE_NAME))
    new_artifact_class = _ValueArtifactType(
        mlmd_artifact_type=artifact_type, base=base_class)
  else:
    new_artifact_class = _ArtifactType(mlmd_artifact_type=artifact_type)

  setattr(new_artifact_class, '_AUTOGENERATED', True)
  return new_artifact_class