message_flow/sagas/orchestration/saga_manager_impl.py (188 lines of code) (raw):

import logging
import re
from typing import Generic, TypeVar

from ...commands.common import CommandReplyOutcome, ReplyMessageHeaders
from ...commands.producer import CommandProducer
from ...messaging.common import IMessage
from ...messaging.consumer import IMessageConsumer
from ...messaging.producer import MessageBuilder
from ..common import SagaReplyHeaders  # type: ignore
from .saga import Saga
from .saga_actions import SagaActions
from .saga_command_producer import SagaCommandProducer
from .saga_data_serde import SagaDataMapping, SagaDataSerde
from .saga_definition import SagaDefinition
from .saga_instance import SagaInstance
from .saga_instance_repository import ISagaInstanceRepository

__all__ = ["SagaManagerImpl"]

Data = TypeVar("Data")


class SagaManagerImpl(Generic[Data]):
    """Orchestrate instances of a single :class:`Saga` type.

    What this class does:

    * :meth:`create` persists a new :class:`SagaInstance`, starts the saga's
      state machine and processes the resulting actions.
    * :meth:`subscribe_to_reply_channel` / :meth:`handle_message` consume
      replies from the ``"<saga_type>-reply"`` channel and feed them back
      into the state machine.
    * Local steps (``actions.is_local``) are advanced in-process by feeding
      the state machine a synthetic SUCCESS/FAILURE reply instead of waiting
      for a real message.
    """

    def __init__(
        self,
        saga: Saga[Data],
        saga_instance_repository: ISagaInstanceRepository,
        command_producer: CommandProducer,
        message_consumer: IMessageConsumer,
        saga_command_producer: SagaCommandProducer,
        saga_data_mapping: SagaDataMapping,
    ) -> None:
        self._logger = logging.getLogger(self.__class__.__name__)
        self._saga = saga
        self._saga_instance_repository = saga_instance_repository
        # NOTE(review): stored but not referenced by any method in this class.
        self._command_producer = command_producer
        self._message_consumer = message_consumer
        self._saga_command_producer = saga_command_producer
        # Passed to SagaDataSerde when deserializing saga data from replies.
        self._saga_data_mapping = saga_data_mapping

    @property
    def _saga_type(self) -> str:
        """The managed saga's type name (``saga.saga_type``)."""
        return self._saga.saga_type

    @property
    def _state_definition(self) -> SagaDefinition[Data]:
        """The saga's definition (state machine).

        :raises RuntimeError: if the saga has no definition.
        """
        sm = self._saga.saga_definition
        if sm is None:
            raise RuntimeError("state machine cannot be null")
        return sm

    def create(self, saga_data: Data) -> SagaInstance:
        """Persist a new saga instance, start the saga and process its first actions.

        :param saga_data: initial saga data; serialized into the new instance.
        :returns: the :class:`SagaInstance` returned by the repository's
            ``save``, mutated in place by the processed actions.
        :raises Exception: the ``local_exception`` reported by the state
            machine's ``start`` is re-raised unchanged. The instance has
            already been saved (and ``on_starting`` called) at that point.
        """
        # NOTE(review): "None"/"????" look like placeholders carried over from
        # a port of code using null / an unknown state; confirm SagaInstance
        # and the repository treat them as "unset".
        saga_instance = SagaInstance(
            self._saga_type,
            "None",
            "????",
            "None",
            SagaDataSerde.serialize_saga_data(saga_data),
        )
        saga_instance = self._saga_instance_repository.save(saga_instance)
        saga_id = saga_instance.saga_id

        self._saga.on_starting(saga_id, saga_data)

        actions: SagaActions[Data] = self._state_definition.start(saga_data)
        if actions.local_exception is not None:
            raise actions.local_exception

        self._process_actions(
            self._saga.saga_type, saga_id, saga_instance, saga_data, actions
        )
        return saga_instance

    def subscribe_to_reply_channel(self) -> None:
        """Subscribe :meth:`handle_message` to this saga type's reply channel."""
        self._message_consumer.subscribe(
            {self._make_saga_reply_channel()},
            self.handle_message,
            queue=self._make_saga_reply_queue(),
        )

    def handle_message(self, message: IMessage) -> None:
        """Dispatch an incoming message.

        Messages carrying a ``REPLY_SAGA_ID`` header are handled as saga
        replies; anything else is logged at WARNING level and ignored.
        """
        self._logger.debug("Handle message invoked %s", message)
        if message.has_header(SagaReplyHeaders.REPLY_SAGA_ID):
            self._handle_reply(message)
        else:
            self._logger.warning(
                "Handle message doesn't know what to do with: %s ", message
            )

    def _handle_reply(self, message: IMessage) -> None:
        """Load the saga instance a reply belongs to and advance its state machine.

        Replies whose ``REPLY_SAGA_TYPE`` is not this manager's saga type are
        ignored.
        """
        if not self._is_reply_for_this_saga_type(message):
            return

        self._logger.debug("Handle reply %s", message)

        saga_id = message.get_required_header(SagaReplyHeaders.REPLY_SAGA_ID)
        saga_type = message.get_required_header(SagaReplyHeaders.REPLY_SAGA_TYPE)

        # NOTE(review): the result of find() is used without a None check;
        # confirm the repository raises for unknown ids.
        saga_instance = self._saga_instance_repository.find(saga_id)
        saga_data: Data = SagaDataSerde.deserialize_saga_data(
            saga_instance.serialized_saga_data,
            self._saga_data_mapping,
        )

        current_state = saga_instance.state_name
        self._logger.info("Current state=%s", current_state)

        actions = self._state_definition.handle_reply(
            saga_type, saga_id, current_state, saga_data, message
        )

        self._logger.info("Handled reply. Sending commands %s", actions.commands)

        self._process_actions(saga_type, saga_id, saga_instance, saga_data, actions)

    def _is_reply_for_this_saga_type(self, message: IMessage) -> bool:
        """Return True if the message's ``REPLY_SAGA_TYPE`` header matches this saga type."""
        if (
            reply_saga_type := message.get_header(SagaReplyHeaders.REPLY_SAGA_TYPE)
        ) is not None:
            return reply_saga_type == self._saga_type
        return False

    def _process_actions(
        self,
        saga_type: str,
        saga_id: str,
        saga_instance: SagaInstance,
        saga_data: Data,
        actions: SagaActions[Data],
    ) -> None:
        """Apply ``actions`` and keep advancing through local steps.

        Each round:

        * If the step raised locally (``actions.local_exception``), feed a
          synthetic FAILURE reply back into the state machine and continue
          with the resulting actions. Nothing is sent or persisted for that
          round.
        * Otherwise send the step's commands, record the request id, update
          and persist ``saga_instance`` (running end-state hooks when the
          saga finished). Then stop if the step was not local, or feed a
          synthetic SUCCESS reply back in if it was.

        :param saga_data: the data the actions were computed from; used
            whenever an action carries no ``updated_saga_data``.
        """
        current_data = saga_data
        while True:
            # `is not None` rather than truthiness, so a legitimately "empty"
            # (falsy) data object is not replaced by stale data. Carrying the
            # value across rounds means a later step without updated data
            # keeps the previous step's result rather than the initial data.
            if actions.updated_saga_data is not None:
                current_data = actions.updated_saga_data

            if actions.local_exception is not None:
                actions = self._state_definition.handle_reply(
                    saga_type,
                    saga_id,
                    actions.updated_state,
                    current_data,
                    self._make_local_reply(CommandReplyOutcome.FAILURE, "Failure"),
                )
                continue

            last_request_id: str = self._saga_command_producer.send_commands(
                self._saga_type,
                saga_id,
                actions.commands,
                self._make_saga_reply_channel(),
            )
            saga_instance.last_request_id = last_request_id

            self._update_state(saga_instance, actions)
            saga_instance.serialized_saga_data = SagaDataSerde.serialize_saga_data(
                current_data
            )

            if actions.is_end_state:
                self._perform_end_state_actions(
                    saga_id,
                    saga_instance,
                    actions.is_compensating,
                    actions.is_failed,
                    current_data,
                )

            self._saga_instance_repository.update(saga_instance)

            if not actions.is_local:
                # A real reply is expected; it will arrive via handle_message.
                break

            actions = self._state_definition.handle_reply(
                saga_type,
                saga_id,
                actions.updated_state,
                current_data,
                self._make_local_reply(CommandReplyOutcome.SUCCESS, "Success"),
            )

    @staticmethod
    def _make_local_reply(outcome: CommandReplyOutcome, reply_type: str) -> IMessage:
        """Build the synthetic reply (payload ``{}``) used to advance past a local step."""
        return (
            MessageBuilder.with_payload(b"{}")
            .with_header(ReplyMessageHeaders.REPLY_OUTCOME, outcome.value)
            .with_header(ReplyMessageHeaders.REPLY_TYPE, reply_type)
            .build()
        )

    def _make_saga_reply_channel(self) -> str:
        """Return the reply channel name, ``"<saga_type>-reply"``."""
        return f"{self._saga_type}-reply"

    def _make_saga_reply_queue(self) -> str:
        """Return the reply queue name derived from the saga type.

        A hyphen is inserted before each run of uppercase letters (except at
        the start) and the result is lower-cased, e.g. ``"OrderSaga"`` ->
        ``"order-saga-queue"``.
        """
        kebab_case_saga_type = re.sub(r"(?!^)([A-Z]+)", r"-\1", self._saga_type).lower()
        return f"{kebab_case_saga_type}-queue"

    def _update_state(
        self, saga_instance: SagaInstance, actions: SagaActions[Data]
    ) -> None:
        """Copy state name (when present) and end/compensating/failed flags onto the instance."""
        if actions.updated_state is not None:
            saga_instance.state_name = actions.updated_state
        saga_instance.end_state = actions.is_end_state
        saga_instance.compensating = actions.is_compensating
        saga_instance.failed = actions.is_failed

    def _perform_end_state_actions(
        self,
        saga_id: str,
        saga_instance: SagaInstance,
        compensating: bool,
        failed: bool,
        saga_data: Data,
    ) -> None:
        """Invoke the saga's end-of-life hooks.

        ``on_saga_failed`` runs when ``failed``. Independently of that,
        ``on_saga_rolled_back`` runs when ``compensating``, and otherwise
        ``on_saga_completed_successfully`` runs.

        NOTE(review): a failed, non-compensating saga therefore triggers both
        ``on_saga_failed`` and ``on_saga_completed_successfully``; confirm
        this is intended. ``saga_instance`` is accepted but not used here.
        """
        if failed:
            self._saga.on_saga_failed(saga_id, saga_data)
        if compensating:
            self._saga.on_saga_rolled_back(saga_id, saga_data)
        else:
            self._saga.on_saga_completed_successfully(saga_id, saga_data)