# modular_sdk/services/sts_service.py (142 lines of code) (raw):
import re
import uuid
from datetime import datetime, timedelta
from functools import cached_property
from time import time
from typing import Optional, TypedDict, List, Tuple, Generator
from botocore.exceptions import ClientError
from modular_sdk.commons.constants import MODULAR_AWS_ACCESS_KEY_ID_ENV, \
MODULAR_AWS_SESSION_TOKEN_ENV, MODULAR_AWS_SECRET_ACCESS_KEY_ENV, \
MODULAR_AWS_CREDENTIALS_EXPIRATION_ENV
from modular_sdk.commons.log_helper import get_logger
from modular_sdk.commons.time_helper import utc_datetime
from modular_sdk.services.aws_creds_provider import AWSCredentialsProvider
from modular_sdk.services.environment_service import EnvironmentService
# Module-level logger used by StsService for its info/debug/warning/error
# messages.
_LOG = get_logger(__name__)
class StsService(AWSCredentialsProvider):
    """
    STS helper built on top of ``AWSCredentialsProvider``.

    Provides role ARN helpers, single and chained ``assume_role`` calls and
    a routine that keeps the temporary credentials stored in the
    environment service fresh.
    """

    class AssumeRoleResult(TypedDict):
        """Temporary credentials produced by the ``assume_*`` methods."""
        aws_access_key_id: str
        aws_secret_access_key: str
        aws_session_token: str
        expiration: datetime

    # role arn, session_name, duration
    AssumeRolePayload = Tuple[str, Optional[str], Optional[int]]

    def __init__(self, environment_service: EnvironmentService,
                 aws_region, aws_access_key_id=None,
                 aws_secret_access_key=None, aws_session_token=None):
        """
        :param environment_service: source of the role chain and storage
            for the re-assumed credentials
        :param aws_region: region forwarded to the parent provider
        :param aws_access_key_id: optional explicit access key
        :param aws_secret_access_key: optional explicit secret key
        :param aws_session_token: optional explicit session token
        """
        super().__init__(
            service_name='sts',
            aws_region=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token
        )
        self._environment_service = environment_service

    @staticmethod
    def generate_unique_session_name(name_body, suffix_length=6):
        """
        Build ``'<name_body>-<suffix>'`` where the suffix is the tail of a
        freshly generated UUID4 string.

        :param name_body: leading part of the session name
        :param suffix_length: number of trailing UUID characters to append;
            zero or a negative value yields an empty suffix
        :return: the generated session name
        """
        token = str(uuid.uuid4())
        # `token[-0:]` would return the whole string, so a non-positive
        # length has to be handled explicitly.
        suffix = token[-suffix_length:] if suffix_length > 0 else ''
        return f'{name_body}-{suffix}'

    @cached_property
    def caller_identity(self):
        """
        Response of ``get_caller_identity``. Requested on first access and
        cached on the instance afterwards.
        """
        return self.client.get_caller_identity()

    @cached_property
    def role_arn_pattern(self) -> re.Pattern:
        """
        Compiled pattern of an IAM role ARN in the ``aws`` partition.

        Besides letters, digits, ``_`` and ``-`` the resource part accepts
        ``+=,.@`` (allowed in IAM role names) and ``/`` (role paths such as
        ``role/service-role/name``). Every ARN accepted by the previous,
        narrower pattern is still accepted.
        """
        return re.compile(
            r'^arn:aws:iam::\d{12}:role/[A-Za-z0-9_+=,.@/-]+$'
        )

    def get_account_id(self) -> str:
        """
        Return the ``'Account'`` value of the cached caller identity.
        """
        return self.caller_identity['Account']

    def build_role_arn(self, maybe_arn: str,
                       account_id: Optional[str] = None) -> str:
        """
        Return a role ARN for the given value.

        :param maybe_arn: either a complete role ARN (returned unchanged)
            or a bare role name
        :param account_id: account to put into the ARN built from a bare
            name; when omitted the account of the caller identity is used
        :return: role ARN
        """
        if self.is_role_arn(maybe_arn):
            return maybe_arn
        account_id = account_id or self.get_account_id()
        return f'arn:aws:iam::{account_id}:role/{maybe_arn}'

    def is_role_arn(self, arn: str) -> bool:
        """
        Tell whether the string matches ``role_arn_pattern``.
        """
        return self.role_arn_pattern.match(arn) is not None

    @staticmethod
    def _credentials_to_result(credentials) -> AssumeRoleResult:
        """
        Map the ``Credentials`` section of an ``assume_role`` response to
        ``AssumeRoleResult``.
        """
        return {
            'aws_access_key_id': credentials.get('AccessKeyId'),
            'aws_secret_access_key': credentials.get('SecretAccessKey'),
            'aws_session_token': credentials.get('SessionToken'),
            'expiration': credentials.get('Expiration')  # datetime UTC
        }

    def assume_role(self, role_arn, role_session_name,
                    duration=900) -> AssumeRoleResult:
        """
        Assume a single role with this service's client.

        :param role_arn: ARN of the role to assume
        :param role_session_name: name of the session
        :param duration: session duration in seconds
        :return: credentials and expiration of the assumed role
        :raises ConnectionAbortedError: when the STS call fails with
            ``ClientError``
        """
        try:
            response = self.client.assume_role(
                RoleArn=role_arn,
                RoleSessionName=role_session_name,
                DurationSeconds=duration
            )
        except ClientError as e:
            error_message = f'Error while assuming {role_arn}'
            _LOG.error(f'{error_message}: {e}')
            raise ConnectionAbortedError(error_message) from e
        return self._credentials_to_result(response.get('Credentials'))

    def assume_roles_chain(self, payloads: List[AssumeRolePayload]
                           ) -> AssumeRoleResult:
        """
        Assume a chain of roles one after another: each role is assumed
        with the credentials obtained from the previous one, the first
        one - with this service's client.

        :param payloads: ``(role arn, session name, duration)`` tuples.
            An empty session name is replaced with a timestamp-based one,
            an empty duration - with 900 seconds
        :return: credentials and expiration of the last assumed role
        :raises AssertionError: when no payloads are given or a payload
            does not consist of exactly three items
        :raises ConnectionAbortedError: when any STS call fails with
            ``ClientError``
        """
        payloads = list(payloads)
        # Explicit raises instead of `assert` statements: the checks must
        # survive `python -O`. The exception type is kept for callers.
        if not payloads:
            raise AssertionError('At least one payload must be given')

        sts_client = self.client
        last_index = len(payloads) - 1
        creds = None
        for index, payload in enumerate(payloads):
            if len(payload) != 3:
                raise AssertionError('Invalid usage of the method')
            arn, session_name, duration = payload
            session_name = (session_name
                            or f'modular_sdk-sdk-session-{time()}')
            duration = duration or 900
            try:
                _LOG.info(f'Assuming {arn} from chain')
                creds = sts_client.assume_role(
                    RoleArn=arn,
                    RoleSessionName=session_name,
                    DurationSeconds=duration
                )['Credentials']
            except ClientError as e:
                error_message = f'Error while assuming {arn} from chain'
                _LOG.error(f'{error_message}: {e}')
                raise ConnectionAbortedError(error_message) from e
            if index != last_index:
                # Only intermediate hops need a client: it performs the
                # next assume_role with the credentials just received.
                sts_client = AWSCredentialsProvider(
                    service_name='sts',
                    aws_region=self._region_name,
                    aws_access_key_id=creds['AccessKeyId'],
                    aws_secret_access_key=creds['SecretAccessKey'],
                    aws_session_token=creds['SessionToken'],
                ).client
        return self._credentials_to_result(creds)

    def assume_roles_default_payloads(self, roles: List[str],
                                      session_name: Optional[str] = None,
                                      last_duration: Optional[int] = 3600
                                      ) -> Generator[AssumeRolePayload, None, None]:
        """
        Build payloads for ``assume_roles_chain`` from a list of role ARNs.

        Strings that are not role ARNs are skipped with a warning. Every
        payload gets the duration of 900 seconds except the last yielded
        one, which gets ``last_duration`` - even when trailing entries of
        ``roles`` were skipped as invalid.

        :param roles: role ARNs in the order they must be assumed
        :param session_name: name for each session, the current timestamp
            is appended to it. When empty, the session name of every
            payload is None
        :param last_duration: session duration for the last payload
        :return: generator of ``(role arn, session name, duration)``
        """
        valid_roles = []
        for role in roles:
            if self.is_role_arn(role):
                valid_roles.append(role)
            else:
                _LOG.warning(f'The string {role} is not a role arn. '
                             f'Skipping it.')
        last_index = len(valid_roles) - 1
        for index, role in enumerate(valid_roles):
            name = f'{session_name}-{time()}' if session_name else None
            # 900 the smallest
            duration = last_duration if index == last_index else 900
            yield role, name, duration

    def assure_modular_credentials_valid(self) -> bool:
        """
        If modular_sdk uses 'modular_assume_role_arn', it works with temp
        aws credentials to be able to interact with another AWS account.
        The creds expiration is kept in envs, so they must be re-assumed
        periodically. This method checks whether the creds are missing or
        expire in less than five minutes and, if so, re-assumes the whole
        chain and stores the new credentials and expiration in envs.

        :return: True in case the role was re-assumed. Otherwise - False
        """
        roles = self._environment_service.modular_assume_role_arn()
        if not roles:
            return False
        expiration = \
            self._environment_service.modular_aws_credentials_expiration()
        threshold = utc_datetime() + timedelta(minutes=5)
        if expiration and threshold <= datetime.fromisoformat(expiration):
            # stored credentials stay valid for at least five more minutes
            return False

        _LOG.info(f'Role {roles[-1]} has not been assumed or has expired. '
                  f'Reassuming the chain: {roles}')
        creds = self.assume_roles_chain(
            list(self.assume_roles_default_payloads(roles))
        )
        _LOG.debug('Credentials received successfully. '
                   'Setting them to envs')
        env = self._environment_service
        env.set(MODULAR_AWS_ACCESS_KEY_ID_ENV, creds['aws_access_key_id'])
        env.set(MODULAR_AWS_SECRET_ACCESS_KEY_ENV,
                creds['aws_secret_access_key'])
        env.set(MODULAR_AWS_SESSION_TOKEN_ENV, creds['aws_session_token'])
        env.set(MODULAR_AWS_CREDENTIALS_EXPIRATION_ENV,
                creds['expiration'].isoformat())
        return True