syndicate/connection/batch_connection.py (201 lines of code) (raw):
"""
Copyright 2018 EPAM Systems, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from boto3 import client
from botocore.waiter import WaiterModel, create_waiter_with_client
from syndicate.commons.log_helper import get_logger
from syndicate.connection.helper import apply_methods_decorator, retry
from syndicate.core.helper import dict_keys_to_camel_case
# Module-level logger, named after this module via get_logger(__name__).
_LOG = get_logger(__name__)
@apply_methods_decorator(retry())
class BatchConnection(object):
    """ AWS Batch connection class.

    Thin wrapper around the boto3 ``batch`` client. Each public method
    builds the keyword arguments of the corresponding client call and
    only includes optional arguments that were actually supplied.
    """

    def __init__(self, region=None, aws_access_key_id=None,
                 aws_secret_access_key=None, aws_session_token=None):
        """ Create the underlying boto3 Batch client.

        :param region: region name passed to the boto3 client
        :param aws_access_key_id: access key id passed to the client
        :param aws_secret_access_key: secret key passed to the client
        :param aws_session_token: session token passed to the client
        """
        self.client = client('batch', region,
                             aws_access_key_id=aws_access_key_id,
                             aws_secret_access_key=aws_secret_access_key,
                             aws_session_token=aws_session_token)
        _LOG.debug('Opened new Batch connection.')

    def create_compute_environment(self, compute_environment_name,
                                   compute_environment_type, state,
                                   service_role=None,
                                   compute_resources=None,
                                   tags=None):
        """ Create a compute environment.

        :param compute_environment_name: sent as ``computeEnvironmentName``
        :param compute_environment_type: sent as ``type``
        :param state: sent as ``state``
        :param service_role: optional, sent as ``serviceRole``
        :param compute_resources: optional dict; its keys are converted
            with ``dict_keys_to_camel_case`` before being sent
        :param tags: optional, sent as ``tags``
        :return: response of ``client.create_compute_environment``
        """
        params = dict(
            computeEnvironmentName=compute_environment_name,
            type=compute_environment_type,
            state=state,
        )
        # botocore parameter validation rejects None for string members,
        # so the optional role is sent only when it was provided.
        if service_role:
            params['serviceRole'] = service_role
        if compute_resources:
            params['computeResources'] = dict_keys_to_camel_case(
                compute_resources)
        if tags:
            params['tags'] = tags
        return self.client.create_compute_environment(**params)

    def update_compute_environment(self, compute_environment, state=None,
                                   compute_resources=None, service_role=None):
        """ Update a compute environment.

        :param compute_environment: sent as ``computeEnvironment``
        :param state: optional, sent as ``state``
        :param compute_resources: optional dict; its keys are converted
            with ``dict_keys_to_camel_case`` before being sent
        :param service_role: optional, sent as ``serviceRole``
        :return: response of ``client.update_compute_environment``
        """
        params = dict(computeEnvironment=compute_environment)
        if state:
            params['state'] = state
        if compute_resources:
            params['computeResources'] = dict_keys_to_camel_case(
                compute_resources)
        if service_role:
            params['serviceRole'] = service_role
        return self.client.update_compute_environment(**params)

    def describe_compute_environments(self, compute_environments):
        """ Describe compute environments.

        :param compute_environments: a single name (str) or a list of
            names; any other value results in a call without the
            ``computeEnvironments`` filter
        :return: response of ``client.describe_compute_environments``
        """
        params = dict()
        if isinstance(compute_environments, str):
            params['computeEnvironments'] = [compute_environments]
        elif isinstance(compute_environments, list):
            params['computeEnvironments'] = compute_environments
        return self.client.describe_compute_environments(**params)

    def delete_compute_environment(self, compute_environment):
        """ Delete a compute environment.

        :param compute_environment: sent as ``computeEnvironment``
        :return: response of ``client.delete_compute_environment``
        """
        return self.client.delete_compute_environment(
            computeEnvironment=compute_environment)

    def create_job_queue(self, job_queue_name, state, priority,
                         compute_environment_order, tags=None):
        """ Create a job queue.

        :param job_queue_name: sent as ``jobQueueName``
        :param state: sent as ``state``
        :param priority: sent as ``priority``
        :param compute_environment_order: iterable of dicts; each item is
            converted with ``dict_keys_to_camel_case``. The caller's
            list is left untouched.
        :param tags: optional, sent as ``tags``
        :return: response of ``client.create_job_queue``
        """
        params = dict(
            jobQueueName=job_queue_name,
            state=state,
            priority=priority,
        )
        # Build a new list instead of rewriting the argument in place, so
        # the caller's data is not modified as a side effect.
        params['computeEnvironmentOrder'] = [
            dict_keys_to_camel_case(item)
            for item in compute_environment_order
        ]
        if tags:
            params['tags'] = tags
        return self.client.create_job_queue(**params)

    def describe_job_queue(self, job_queues=None, max_results=None,
                           next_token=None):
        """ Describe job queues.

        :param job_queues: a single name (str) or a list of names; a
            falsy value (None, empty string, empty list) is sent as an
            empty ``jobQueues`` list
        :param max_results: optional, sent as ``maxResults``
        :param next_token: optional, sent as ``nextToken``
        :return: response of ``client.describe_job_queues``
        """
        params = dict()
        if not job_queues:
            params['jobQueues'] = []
        elif isinstance(job_queues, str):
            params['jobQueues'] = [job_queues]
        elif isinstance(job_queues, list):
            params['jobQueues'] = job_queues
        if max_results:
            params['maxResults'] = max_results
        if next_token:
            params['nextToken'] = next_token
        return self.client.describe_job_queues(**params)

    def update_job_queue(self, job_queue, state=None, priority=None,
                         compute_environment_order=None):
        """ Update a job queue.

        :param job_queue: sent as ``jobQueue``
        :param state: optional, sent as ``state``
        :param priority: optional, sent as ``priority``; ``0`` is
            forwarded, only ``None`` means "not provided"
        :param compute_environment_order: optional, sent as-is as
            ``computeEnvironmentOrder``
        :return: response of ``client.update_job_queue``
        """
        params = dict(jobQueue=job_queue)
        if state:
            params['state'] = state
        # Compare against None: a truthiness test would drop priority 0.
        if priority is not None:
            params['priority'] = priority
        if compute_environment_order:
            # NOTE(review): unlike create_job_queue, the items are not
            # passed through dict_keys_to_camel_case here - confirm that
            # callers already supply camelCase keys.
            params['computeEnvironmentOrder'] = compute_environment_order
        return self.client.update_job_queue(**params)

    def delete_job_queue(self, job_queue):
        """ Delete a job queue.

        :param job_queue: sent as ``jobQueue``
        :return: response of ``client.delete_job_queue``
        """
        return self.client.delete_job_queue(
            jobQueue=job_queue
        )

    def register_job_definition(self, job_definition_name, job_definition_type,
                                parameters=None,
                                container_properties=None,
                                node_properties=None, retry_strategy=None,
                                propagate_tags=None,
                                timeout=None, tags=None,
                                platform_capabilities=None):
        """ Register a job definition.

        ``parameters``, ``container_properties``, ``node_properties``,
        ``retry_strategy`` and ``timeout`` are converted with
        ``dict_keys_to_camel_case`` before being sent; the other optional
        arguments are forwarded unchanged.

        :param job_definition_name: sent as ``jobDefinitionName``
        :param job_definition_type: sent as ``type``
        :param parameters: optional, sent as ``parameters``
        :param container_properties: optional, sent as
            ``containerProperties``
        :param node_properties: optional, sent as ``nodeProperties``
        :param retry_strategy: optional, sent as ``retryStrategy``
        :param propagate_tags: optional, sent as ``propagateTags``
            whenever it is not None (so False is forwarded)
        :param timeout: optional, sent as ``timeout``
        :param tags: optional, sent as ``tags``
        :param platform_capabilities: optional, sent as
            ``platformCapabilities``
        :return: response of ``client.register_job_definition``
        """
        params = dict(
            jobDefinitionName=job_definition_name,
            type=job_definition_type,
        )
        if parameters:
            params['parameters'] = dict_keys_to_camel_case(parameters)
        if container_properties:
            params['containerProperties'] = dict_keys_to_camel_case(
                container_properties)
        if node_properties:
            params['nodeProperties'] = dict_keys_to_camel_case(node_properties)
        if retry_strategy:
            params['retryStrategy'] = dict_keys_to_camel_case(retry_strategy)
        if propagate_tags is not None:
            params['propagateTags'] = propagate_tags
        if timeout:
            params['timeout'] = dict_keys_to_camel_case(timeout)
        if tags:
            params['tags'] = tags
        if platform_capabilities:
            params['platformCapabilities'] = platform_capabilities
        return self.client.register_job_definition(**params)

    def describe_job_definition(self, job_definition, max_results=None,
                                status=None, next_token=None):
        """ Describe the job definitions registered under a name.

        :param job_definition: sent as ``jobDefinitionName``
        :param max_results: optional, sent as ``maxResults``
        :param status: optional, sent as ``status``
        :param next_token: optional, sent as ``nextToken`` to request the
            following page of results
        :return: response of ``client.describe_job_definitions``
        """
        params = dict(jobDefinitionName=job_definition)
        if max_results:
            params['maxResults'] = max_results
        if status:
            params['status'] = status
        if next_token:
            params['nextToken'] = next_token
        return self.client.describe_job_definitions(**params)

    def deregister_job_definition(self, job_definition_name):
        """ Deregister every ACTIVE revision of a job definition.

        :param job_definition_name: name of the job definition; each
            revision is deregistered as ``<name>:<revision>``
        """
        revisions = self._get_job_def_revisions(
            job_definition_name=job_definition_name)
        for revision in revisions:
            job_definition = '{0}:{1}'.format(job_definition_name, revision)
            self.client.deregister_job_definition(
                jobDefinition=job_definition
            )

    def get_compute_environment_waiter(self):
        """ Build a waiter polling ``DescribeComputeEnvironments``.

        The waiter succeeds when every returned status is ``VALID`` and
        fails as soon as any status is ``INVALID``; it polls every
        2 seconds for at most 10 attempts.

        :return: waiter created by ``create_waiter_with_client``
        """
        waiter_id = 'ComputeEnvironmentWaiter'
        model = WaiterModel({
            'version': 2,
            'waiters': {
                waiter_id: {
                    'delay': 2,
                    'operation': 'DescribeComputeEnvironments',
                    'maxAttempts': 10,
                    'acceptors': [
                        {
                            'expected': 'VALID',
                            'matcher': 'pathAll',
                            'state': 'success',
                            'argument': 'computeEnvironments[].status'
                        },
                        {
                            'expected': 'INVALID',
                            'matcher': 'pathAny',
                            'state': 'failure',
                            'argument': 'computeEnvironments[].status'
                        }
                    ]
                }
            }
        })
        return create_waiter_with_client(waiter_id, model, self.client)

    def get_job_queue_waiter(self):
        """ Build a waiter polling ``DescribeJobQueues``.

        The waiter succeeds when every returned status is ``VALID`` and
        fails as soon as any status is ``INVALID``; it polls every
        second for at most 10 attempts.

        :return: waiter created by ``create_waiter_with_client``
        """
        waiter_id = 'JobQueueWaiter'
        model = WaiterModel({
            'version': 2,
            'waiters': {
                waiter_id: {
                    'delay': 1,
                    'operation': 'DescribeJobQueues',
                    'maxAttempts': 10,
                    'acceptors': [
                        {
                            'expected': 'VALID',
                            'matcher': 'pathAll',
                            'state': 'success',
                            'argument': 'jobQueues[].status'
                        },
                        {
                            'expected': 'INVALID',
                            'matcher': 'pathAny',
                            'state': 'failure',
                            'argument': 'jobQueues[].status'
                        }
                    ]
                }
            }
        })
        return create_waiter_with_client(waiter_id, model, self.client)

    def _get_job_def_revisions(self, job_definition_name):
        """ Collect the revision numbers of all ACTIVE job definitions.

        Follows ``nextToken`` so that revisions beyond the first page of
        results are included as well.

        :param job_definition_name: name of the job definition
        :return: list of revisions whose status is ``ACTIVE``
        """
        revisions = []
        next_token = None
        while True:
            page = self.describe_job_definition(
                job_definition=job_definition_name,
                next_token=next_token)
            for job_def in page['jobDefinitions']:
                if job_def.get('status') == 'ACTIVE':
                    revisions.append(job_def.get('revision'))
            next_token = page.get('nextToken')
            # No token means the last page has been read.
            if not next_token:
                break
        return revisions