syndicate/connection/sns_connection.py (175 lines of code) (raw):

""" Copyright 2018 EPAM Systems, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import uuid from json import dumps, loads from boto3 import client from botocore.exceptions import ClientError from syndicate.commons.log_helper import get_logger from syndicate.connection.helper import apply_methods_decorator, retry _LOG = get_logger(__name__) @apply_methods_decorator(retry()) class SNSConnection(object): def __init__(self, region=None, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None): self.client = client('sns', region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token) _LOG.debug('Opened new SNS connection.') def create_topic(self, name, tags): """ Crete SNS topic and return topic arn. :type name: str :type tags: list """ params = dict( Name=name ) if tags: params['Tags'] = tags return self.client.create_topic(**params)['TopicArn'] def subscribe(self, endpoint, topic_name, protocol): """ :param protocol: http -- delivery of JSON-encoded message via HTTP POST https -- delivery of JSON-encoded message via HTTPS POST email -- delivery of message via SMTP email-json -- delivery of JSON-encoded message via SMTP sms -- delivery of message via SMS sqs -- delivery of JSON-encoded message to an Amazon SQS queue application -- delivery of JSON-encoded message to an EndpointArn for a mobile app and device. lambda -- delivery of JSON message to an AWS Lambda function :type protocol: str :type topic_name: str :param endpoint: http protocol, the endpoint is an URL beginning with "http://" https protocol, the endpoint is a URL beginning with "https://" email protocol, the endpoint is an email address email-json protocol, the endpoint is an email address sms protocol, the endpoint is a phone number of an SMS-enabled device sqs protocol, the endpoint is the ARN of an Amazon SQS queue application protocol, the endpoint is the EndpointArn of a mobile app and device. lambda protocol, the endpoint is the ARN of an AWS Lambda function :type endpoint: str """ topic_arn = self.get_topic_arn(topic_name) if topic_arn is None: raise AssertionError( 'Topic does not exist: {0}.'.format(topic_name)) self.client.subscribe(TopicArn=topic_arn, Protocol=protocol, Endpoint=endpoint) return topic_arn def get_topic_arn(self, name): """ Get topic arn by name. :type name: str """ topics = self.get_topics() for each in topics: if name == each['TopicArn'].split(':')[-1]: return each['TopicArn'] def get_platform_application(self, name): """ Get application arn by name. :type name: str """ applications = self.get_platform_applications() for each in applications: resolved_item = each['PlatformApplicationArn'].split(':')[-1] if name == resolved_item.split('/')[-1]: return each['PlatformApplicationArn'] def is_user_subscribed(self, endpoint, topic_name): topic_arn = self.get_topic_arn(topic_name) subscriptions = self.client.list_subscriptions_by_topic( TopicArn=topic_arn)['Subscriptions'] for each in subscriptions: if endpoint == each['Endpoint']: return True def publish_message(self, topic_name, message): topic_arn = self.get_topic_arn(topic_name) return self.client.publish( TargetArn=topic_arn, Message=message, MessageAttributes={ 'string': { 'DataType': 'String', 'StringValue': ' ' } } ) def get_topics(self): """ Get all topics.""" topics = [] response = self.client.list_topics() topics.extend(response.get('Topics')) token = response.get('NextToken') while token: response = self.client.list_topics(NextToken=token) topics.extend(response.get('Topics')) token = response.get('NextToken') return topics def get_platform_applications(self): """ Get all platform applications.""" applications = [] response = self.client.list_platform_applications() applications.extend(response.get('PlatformApplications')) token = response.get('NextToken') while token: response = self.client.list_platform_applications(NextToken=token) applications.extend(response.get('PlatformApplications')) token = response.get('NextToken') return applications def remove_topic_by_arn(self, topic_arn, log_not_found_error=True): """ Remove topic by arn. :type topic_arn: str :type log_not_found_error: boolean, parameter is needed for proper log handling in the retry decorator """ # make get api call first, because the delete function is idempotent if self.get_topic_attributes(topic_arn): self.client.delete_topic(TopicArn=topic_arn) def remove_topic_by_name(self, topic_name): """ Remove topic by arn. :type topic_name: str """ arn = self.get_topic_arn(topic_name) if arn: self.client.delete_topic(TopicArn=arn) def set_topic_attribute(self, topic_arn, attr_name, attr_value): self.client.set_topic_attributes( TopicArn=topic_arn, AttributeName=attr_name, AttributeValue=attr_value ) def allow_service_invoke(self, topic_arn, service): existing_attr = self.get_topic_attributes(topic_arn) existing_policy = existing_attr['Attributes']['Policy'] existing_policy_dict = loads(existing_policy) policy = { "Sid": str(uuid.uuid1()), "Effect": "Allow", "Principal": { "Service": "{0}".format(service) }, "Action": "sns:Publish", "Resource": "{0}".format(topic_arn) } existing_policy_dict['Statement'].append(policy) self.set_topic_attribute(topic_arn, 'Policy', dumps(existing_policy_dict)) def get_topic_attributes(self, topic_arn): return self.client.get_topic_attributes( TopicArn=topic_arn ) def get_platform_application_attributes(self, application_arn): return self.client.get_platform_application_attributes( PlatformApplicationArn=application_arn ) def add_account_permission(self, topic_arn, account_id, action, label): if isinstance(account_id, str): account_id = [account_id] if not isinstance(account_id, list): raise AssertionError('Incorrect account id {0}'.format(account_id)) if isinstance(action, str): action = [action] if not isinstance(action, list): raise AssertionError('Incorrect action {0}'.format(action)) self.client.add_permission(TopicArn=topic_arn, Label=label, AWSAccountId=account_id, ActionName=action) def revoke_account_permission(self, topic_arn, label): self.client.remove_permission(TopicArn=topic_arn, Label=label) def list_subscriptions_by_topic(self, topic_arn): subscriptions = [] try: response = self.client.list_subscriptions_by_topic( TopicArn=topic_arn) except ClientError as e: if e.response['Error']['Code'] == 'NotFound': _LOG.warn(f'SNS topic \'{topic_arn}\' is not found') return subscriptions else: raise e subscriptions.extend(response.get('Subscriptions')) token = response.get('NextToken') while token: response = self.client.list_subscriptions_by_topic( TopicArn=topic_arn, NextToken=token) subscriptions.extend(response.get('Subscriptions')) token = response.get('NextToken') return subscriptions def unsubscribe(self, subscription_arn): self.client.unsubscribe(SubscriptionArn=subscription_arn) def create_platform_endpoint(self, platform_application_arn, token): response = self.client.create_platform_endpoint( PlatformApplicationArn=platform_application_arn, Token=token ) return response.get('EndpointArn') def create_platform_application(self, name, platform, attributes): response = self.client.create_platform_application( Name=name, Platform=platform, Attributes=attributes) return response.get('PlatformApplicationArn') def remove_application_by_arn(self, application_arn, log_not_found_error=True): """ Remove application by arn. :type application_arn: str :type log_not_found_error boolean, parameter is needed for proper log handling in the retry decorator """ self.client.delete_platform_application( PlatformApplicationArn=application_arn) def list_subscriptions(self): paginator = self.client.get_paginator('list_subscriptions') subscriptions = [] for page in paginator.paginate(): subscriptions.extend(page['Subscriptions']) return subscriptions