in tfx/orchestration/metadata_test_utils.py [0:0]
def testUpdateExecution(self):
  """Checks that execution updates and publishing show up in the store.

  The test registers an execution, updates it twice (first with a new input
  artifact, then with the artifact that was stored by the first update) and
  finally publishes it with an output artifact. After each step the
  execution, its events and its artifacts are read back from the store and
  compared against the expected values.

  NOTE(review): the hard-coded artifact ids 1 and 2 presumably rely on the
  store being empty when the test starts -- confirm against the fixture.
  """
  with self.metadata() as m:

    def run_context_id():
      # Looked up on every use, exactly as the assertions need it.
      return m.get_component_run_context(self._component_info).id

    def fetch_and_check_execution(expected_k, expected_state,
                                  expected_last_known_state):
      # Unpacking raises unless the run context holds exactly one execution.
      (found,) = m.store.get_executions_by_context(run_context_id())
      self.assertEqual(found.properties['k'].string_value, expected_k)
      self.assertEqual(found.properties['state'].string_value, expected_state)
      self.assertEqual(found.last_known_state, expected_last_known_state)
      return found

    # Step 1: register a new execution with no inputs.
    contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)
    m.register_execution(
        input_artifacts={},
        exec_properties={'k': 'v1'},
        pipeline_info=self._pipeline_info,
        component_info=self._component_info,
        contexts=contexts)
    execution = fetch_and_check_execution(
        'v1',
        metadata.EXECUTION_STATE_NEW,
        metadata_store_pb2.Execution.RUNNING)

    # Step 2: update the exec properties and attach a brand-new input.
    new_input = standard_artifacts.Examples()
    m.update_execution(
        execution,
        self._component_info,
        input_artifacts={'input_a': [new_input]},
        exec_properties={'k': 'v2'},
        contexts=contexts)
    execution = fetch_and_check_execution(
        'v2',
        metadata.EXECUTION_STATE_NEW,
        metadata_store_pb2.Execution.RUNNING)
    (input_event,) = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(input_event.artifact_id, 1)
    (stored_artifact,) = m.store.get_artifacts_by_context(run_context_id())
    self.assertEqual(stored_artifact.id, 1)

    # Step 3: update again, this time passing the artifact read back from
    # the store. The single-element unpacking below requires that there is
    # still exactly one event for the execution.
    existing_input = standard_artifacts.Examples()
    existing_input.set_mlmd_artifact(stored_artifact)
    m.update_execution(
        execution,
        self._component_info,
        input_artifacts={'input_a': [existing_input]})
    (input_event,) = m.store.get_events_by_execution_ids([execution.id])
    self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)

    # Step 4: publish the execution with one output artifact.
    m.publish_execution(
        self._component_info,
        output_artifacts={'output': [standard_artifacts.Model()]},
        exec_properties={'k': 'v3'})
    execution = fetch_and_check_execution(
        'v3',
        metadata.EXECUTION_STATE_COMPLETE,
        metadata_store_pb2.Execution.COMPLETE)

    events = m.store.get_events_by_execution_ids([execution.id])
    self.assertLen(events, 2)
    (output_event,) = [
        e for e in events if e.type == metadata_store_pb2.Event.OUTPUT
    ]
    self.assertEqual(output_event.artifact_id, 2)

    artifacts = m.store.get_artifacts_by_context(run_context_id())
    self.assertLen(artifacts, 2)
    (output_artifact,) = [a for a in artifacts if a.id == 2]
    self.assertEqual(output_artifact.state, metadata_store_pb2.Artifact.LIVE)