def testUpdateExecution()

in tfx/orchestration/metadata_test_utils.py [0:0]


  def testUpdateExecution(self):
    """Exercises update_execution followed by publish_execution.

    The test runs four steps against one execution:

    1. Register a new execution with exec property k=v1.
    2. Update it with a fresh input artifact and k=v2.
    3. Update it again, re-supplying the already-stored artifact as input.
    4. Publish it with one output artifact and k=v3.

    After each step it checks the stored execution, its events, and the
    artifacts attached to the component run context.
    """
    with self.metadata() as m:
      component_info = self._component_info

      def run_context_id():
        # Resolved anew on every call, so each check sees the current context.
        return m.get_component_run_context(component_info).id

      def fetch_single_execution():
        # Tuple unpacking enforces that the context holds exactly one execution.
        (single,) = m.store.get_executions_by_context(run_context_id())
        return single

      def assert_execution(execution, k_value, state, last_known_state):
        self.assertEqual(execution.properties['k'].string_value, k_value)
        self.assertEqual(execution.properties['state'].string_value, state)
        self.assertEqual(execution.last_known_state, last_known_state)

      # Step 1: register a brand-new execution.
      contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)
      m.register_execution(
          input_artifacts={},
          exec_properties={'k': 'v1'},
          pipeline_info=self._pipeline_info,
          component_info=component_info,
          contexts=contexts)
      execution = fetch_single_execution()
      assert_execution(execution, 'v1', metadata.EXECUTION_STATE_NEW,
                       metadata_store_pb2.Execution.RUNNING)

      # Step 2: attach a new input artifact and change the exec property.
      m.update_execution(
          execution,
          component_info,
          input_artifacts={'input_a': [standard_artifacts.Examples()]},
          exec_properties={'k': 'v2'},
          contexts=contexts)
      execution = fetch_single_execution()
      assert_execution(execution, 'v2', metadata.EXECUTION_STATE_NEW,
                       metadata_store_pb2.Execution.RUNNING)
      (input_event,) = m.store.get_events_by_execution_ids([execution.id])
      self.assertEqual(input_event.artifact_id, 1)
      (stored_input,) = m.store.get_artifacts_by_context(run_context_id())
      self.assertEqual(stored_input.id, 1)

      # Step 3: pass the already-stored artifact back in as the input.
      reused_input = standard_artifacts.Examples()
      reused_input.set_mlmd_artifact(stored_input)
      m.update_execution(
          execution,
          component_info,
          input_artifacts={'input_a': [reused_input]})
      (input_event,) = m.store.get_events_by_execution_ids([execution.id])
      self.assertEqual(input_event.type, metadata_store_pb2.Event.INPUT)

      # Step 4: publish with one output artifact.
      m.publish_execution(
          component_info,
          output_artifacts={'output': [standard_artifacts.Model()]},
          exec_properties={'k': 'v3'})
      execution = fetch_single_execution()
      assert_execution(execution, 'v3', metadata.EXECUTION_STATE_COMPLETE,
                       metadata_store_pb2.Execution.COMPLETE)
      events = m.store.get_events_by_execution_ids([execution.id])
      self.assertLen(events, 2)
      (output_event,) = [
          e for e in events if e.type == metadata_store_pb2.Event.OUTPUT
      ]
      self.assertEqual(output_event.artifact_id, 2)
      artifacts = m.store.get_artifacts_by_context(run_context_id())
      self.assertLen(artifacts, 2)
      (output_artifact,) = [a for a in artifacts if a.id == 2]
      self.assertEqual(output_artifact.state, metadata_store_pb2.Artifact.LIVE)