syndicate/connection/appsync_connection.py (371 lines of code) (raw):

""" Copyright 2018 EPAM Systems, Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import time from boto3 import client from syndicate.commons.log_helper import get_logger from syndicate.connection.helper import apply_methods_decorator, retry from syndicate.core.helper import dict_keys_to_camel_case _LOG = get_logger(__name__) DATA_SOURCE_TYPE_CONFIG_MAPPING = { 'AWS_LAMBDA': 'lambdaConfig', 'AMAZON_DYNAMODB': 'dynamodbConfig', 'AMAZON_ELASTICSEARCH': 'elasticsearchConfig', 'HTTP': 'httpConfig', 'RELATIONAL_DATABASE': 'relationalDatabaseConfig', 'AMAZON_OPENSEARCH_SERVICE': 'openSearchServiceConfig', 'AMAZON_EVENTBRIDGE': 'eventBridgeConfig' } REDUNDANT_RESOLVER_EXCEPTION_TEXT = 'Only one resolver is allowed per field' DATA_SOURCE_EXISTS_EXCEPTION_TEXT = \ 'Data source with name {name} already exists' @apply_methods_decorator(retry()) class AppSyncConnection(object): """ AWS AppSync connection class. """ def __init__(self, region=None, aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None): self.client = client('appsync', region, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, aws_session_token=aws_session_token) self.region = region _LOG.debug('Opened new AppSync connection.') # ------------------------ Create ------------------------ def create_api(self, name, event_config, tags, owner): params = dict( name=name ) if event_config: params['eventConfig'] = tags if tags: params['tags'] = tags if owner: params['ownerContact'] = owner return self.client.create_api(**params) def create_graphql_api(self, name: str, auth_type: str, tags: dict = None, user_pool_config: dict = None, open_id_config: dict = None, lambda_auth_config: dict = None, log_config: dict = None, xray_enabled: bool = None, extra_auth_types: list = None): params = dict( name=name, authenticationType=auth_type ) if tags: params['tags'] = tags if user_pool_config: user_pool_config = dict_keys_to_camel_case(user_pool_config) params['userPoolConfig'] = user_pool_config if open_id_config: open_id_config = dict_keys_to_camel_case(open_id_config) params['openIDConnectConfig'] = open_id_config if lambda_auth_config: lambda_auth_config = dict_keys_to_camel_case(lambda_auth_config) params['lambdaAuthorizerConfig'] = lambda_auth_config if log_config: params['logConfig'] = dict_keys_to_camel_case(log_config) if xray_enabled: params['xrayEnabled'] = xray_enabled if extra_auth_types: params['additionalAuthenticationProviders'] = extra_auth_types return self.client.create_graphql_api(**params)['graphqlApi']['apiId'] def create_schema(self, api_id: str, definition: str): response = self.client.start_schema_creation( apiId=api_id, definition=str.encode(definition) ) status = response['status'] details = response.get('details', '') while status == 'PROCESSING': time.sleep(2) response = self.client.get_schema_creation_status( apiId=api_id ) status = response['status'] details = response.get('details', '') return status, details def create_type(self, api_id: str, definition: str, format: str): params = dict( api_id=api_id, definition=definition, format=format ) return self.client.create_type(**params)['type'] def create_data_source(self, api_id: str, name: str, source_type: str, source_config: dict = None, description: str = None, service_role_arn: str = None): params = dict( apiId=api_id, name=name, type=source_type ) config_key = DATA_SOURCE_TYPE_CONFIG_MAPPING.get(source_type) if config_key: source_config = dict_keys_to_camel_case(source_config) params[config_key] = source_config if description: params['description'] = description if service_role_arn: params['serviceRoleArn'] = service_role_arn try: return self.client.create_data_source(**params)['dataSource'] except self.client.exceptions.BadRequestException as e: error_text = DATA_SOURCE_EXISTS_EXCEPTION_TEXT.format(name=name) if error_text in str(e): _LOG.warning(error_text) return else: raise e def create_function(self, api_id: str, func_params: dict): params = dict_keys_to_camel_case(func_params) params['apiId'] = api_id return self.client.create_function(**params)['functionConfiguration'] def create_resolver(self, api_id: str, type_name: str, field_name: str, kind: str, runtime: str = None, data_source_name: str = None, code: str = None, request_mapping_template: str = None, response_mapping_template: str = None, max_batch_size: int = None, pipeline_config: dict = None): params = dict( apiId=api_id, typeName=type_name, fieldName=field_name, kind=kind ) if runtime: params['runtime'] = runtime if data_source_name: params['dataSourceName'] = data_source_name if code: params['code'] = code if request_mapping_template: params['requestMappingTemplate'] = request_mapping_template if response_mapping_template: params['responseMappingTemplate'] = response_mapping_template if pipeline_config: params['pipelineConfig'] = pipeline_config if max_batch_size: params['maxBatchSize'] = max_batch_size try: return self.client.create_resolver(**params)['resolver'] except self.client.exceptions.BadRequestException as e: if REDUNDANT_RESOLVER_EXCEPTION_TEXT in str(e): _LOG.warning(f'Only one resolver is allowed per field ' f'{field_name}; type {type_name}. ' f'Ignoring redundant resolver.') return else: raise e def create_api_key(self, api_id: str, description: str = None, expires: int = None): params = dict( apiId=api_id ) if description: params['description'] = description if expires: params['expires'] = expires return self.client.create_api_key(**params)['apiKey'] # ------------------------ Get ------------------------ def get_graphql_api(self, api_id: str): return self.client.get_graphql_api(apiId=api_id)['graphqlApi'] def get_data_source(self, api_id: str, name: str): try: return self.client.get_data_source( apiId=api_id, name=name)['dataSource'] except self.client.exceptions.NotFoundException: _LOG.warning(f'Data source {name} not found') return def get_graphql_api_by_name(self, name): # TODO change list_graphql_apis to list_apis when upgrade boto3 version def process_apis(resume_token=None): pagination_conf = { 'MaxItems': 60, 'PageSize': 10 } if resume_token: pagination_conf['StartingToken'] = resume_token response = paginator.paginate( PaginationConfig=pagination_conf ) for page in response: apis.extend( [api for api in page['graphqlApis'] if api['name'] == name] ) return response.resume_token apis = [] paginator = self.client.get_paginator('list_graphql_apis') next_token = process_apis() while next_token: next_token = process_apis(next_token) if len(apis) == 1: return apis[0] if len(apis) > 1: _LOG.warn(f'AppSync API can\'t be identified unambiguously ' f'because there is more than one resource with the name ' f'"{name}" in the region {self.region}.') else: _LOG.warn(f'AppSync API with the name "{name}" ' f'not found in the region {self.region}') def get_resolver(self, api_id: str, type_name: str, field_name: str): try: return self.client.get_resolver( apiId=api_id, typeName=type_name, fieldName=field_name )['resolver'] except self.client.exceptions.NotFoundException: _LOG.warning(f'Resolver for type {type_name} and field ' f'{field_name} not found') return def list_resolvers(self, api_id: str, type_name: str): result = [] try: paginator = self.client.get_paginator('list_resolvers') for response in paginator.paginate(apiId=api_id, typeName=type_name): result.extend(response['resolvers']) except self.client.exceptions.NotFoundException: return return result def list_data_sources(self, api_id: str) -> list | None: result = [] try: paginator = self.client.get_paginator('list_data_sources') for response in paginator.paginate(apiId=api_id): result.extend(response['dataSources']) except self.client.exceptions.NotFoundException: return return result def list_types(self, api_id: str) -> list | None: result = [] try: paginator = self.client.get_paginator('list_types') for response in paginator.paginate(apiId=api_id, format='JSON'): result.extend(response['types']) except self.client.exceptions.NotFoundException: return return result def list_functions(self, api_id: str) -> list: result = [] try: paginator = self.client.get_paginator('list_functions') for response in paginator.paginate(apiId=api_id): result.extend(response['functions']) except self.client.exceptions.NotFoundException: pass return result def get_schema(self, api_id: str, format: str = None): return self.client.get_introspection_schema( apiId=api_id, format='SDL' if not format else format )['schema'] def list_api_keys(self, api_id: str) -> list: return self.client.list_api_keys(apiId=api_id)['apiKeys'] # ------------------------ Update ------------------------ def update_data_source(self, api_id: str, name: str, source_type: str, source_config: dict = None, description: dict = None, service_role_arn: str = None): params = dict( apiId=api_id, name=name, type=source_type ) config_key = DATA_SOURCE_TYPE_CONFIG_MAPPING.get(source_type) if config_key: source_config = dict_keys_to_camel_case(source_config) params[config_key] = source_config if description: params['description'] = description if service_role_arn: params['serviceRoleArn'] = service_role_arn return self.client.update_data_source(**params)['dataSource'] def update_function(self, api_id: str, function_id: str, func_params: dict): params = dict_keys_to_camel_case(func_params) params.update({ 'apiId': api_id, 'functionId': function_id }) return self.client.update_function(**params)['functionConfiguration'] def update_resolver(self, api_id: str, type_name: str, field_name: str, kind: str, runtime: str = None, data_source_name: str = None, request_mapping_template: str = None, response_mapping_template: str = None, code: str = None, max_batch_size: int = None, pipeline_config: dict = None): params = dict( apiId=api_id, typeName=type_name, fieldName=field_name, kind=kind ) if runtime: params['runtime'] = runtime if data_source_name: params['dataSourceName'] = data_source_name if code: params['code'] = code if request_mapping_template: params['requestMappingTemplate'] = request_mapping_template if response_mapping_template: params['responseMappingTemplate'] = response_mapping_template if max_batch_size: params['maxBatchSize'] = max_batch_size if pipeline_config: params['pipelineConfig'] = pipeline_config return self.client.update_resolver(**params)['resolver'] def update_graphql_api(self, api_id: str, name: str, log_config: dict = None, auth_type: str = None, user_pool_config: dict = None, open_id_config: dict = None, lambda_auth_config: dict = None, xray_enabled: bool = None, extra_auth_types: list = None): params = dict( apiId=api_id, name=name ) if auth_type: params['authenticationType'] = auth_type if user_pool_config: user_pool_config = dict_keys_to_camel_case(user_pool_config) params['userPoolConfig'] = user_pool_config if open_id_config: open_id_config = dict_keys_to_camel_case(open_id_config) params['openIDConnectConfig'] = open_id_config if lambda_auth_config: lambda_auth_config = dict_keys_to_camel_case(lambda_auth_config) params['lambdaAuthorizerConfig'] = lambda_auth_config if log_config: params['logConfig'] = dict_keys_to_camel_case(log_config) if xray_enabled: params['xrayEnabled'] = xray_enabled if extra_auth_types: params['additionalAuthenticationProviders'] = extra_auth_types return self.client.update_graphql_api(**params)['graphqlApi']['apiId'] # ------------------------ Delete ------------------------ def delete_graphql_api(self, api_id: str): self.client.delete_graphql_api(apiId=api_id) def delete_data_source(self, api_id: str, name: str): return self.client.delete_data_source( apiId=api_id, name=name ) def delete_function(self, api_id: str, func_id: str): try: return self.client.delete_function( apiId=api_id, functionId=func_id ) except Exception as e: message = ('Cannot delete a function which is currently used by a ' 'resolver') if message in str(e): _LOG.warn(str(e)) else: raise def delete_resolver(self, api_id: str, type_name: str, field_name: str): return self.client.delete_resolver( apiId=api_id, typeName=type_name, fieldName=field_name )