modular_sdk/services/aws_creds_provider.py (85 lines of code) (raw):

from datetime import datetime, timedelta
from functools import cached_property
from typing import Optional

import boto3
from botocore.client import BaseClient
from botocore.config import Config

from modular_sdk.modular import Modular
from modular_sdk.commons.constants import Env
from modular_sdk.commons.time_helper import utc_datetime
from modular_sdk.commons.log_helper import get_logger

_LOG = get_logger(__name__)


class AWSCredentialsProvider:
    """
    Lazily builds a single boto3 client for one service and region.

    Explicit credentials are optional; when they are omitted, ``None`` is
    passed to ``boto3.client`` for each credential argument.
    """

    # client provider
    def __init__(self, service_name: str, aws_region: str,
                 aws_access_key_id: Optional[str] = None,
                 aws_secret_access_key: Optional[str] = None,
                 aws_session_token: Optional[str] = None):
        """
        :param service_name: boto3 service name the client is created for
        :param aws_region: region name passed to the client
        :param aws_access_key_id: optional access key id
        :param aws_secret_access_key: optional secret access key
        :param aws_session_token: optional session token
        :raises KeyError: if exactly one of aws_access_key_id and
            aws_secret_access_key is given
        """
        # XOR on truthiness: the key pair must be given together or not
        # at all.
        if bool(aws_access_key_id) ^ bool(aws_secret_access_key):
            error_message = 'aws_access_key_id and aws_secret_access_key ' \
                            'must be both specified.'
            # NOTE: RuntimeError would describe this better, but KeyError
            # is kept because callers may already catch it.
            raise KeyError(error_message)
        self._service_name = service_name
        self._region_name = aws_region
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token

    @cached_property
    def client(self) -> BaseClient:
        """
        Return the boto3 client, creating it on first access. The result
        is cached on the instance by ``cached_property``.
        """
        _LOG.info(f'Initializing {self._service_name} boto3 client')
        return boto3.client(
            service_name=self._service_name,
            region_name=self._region_name,
            aws_access_key_id=self._aws_access_key_id,
            aws_secret_access_key=self._aws_secret_access_key,
            aws_session_token=self._aws_session_token,
        )


class ModularAssumeRoleClient:
    """
    Descriptor class for boto3 client attributes when we need them to be
    refreshed when modular_sdk assume role creds expire.
    Example:
        class SSMClient:
            client = ModularAssumeRoleClient('ssm')
            def get_parameter(self, name):
                return self.client.get_parameter(Name=name,
                                                 WithDecryption=True)
    Such a client will be automatically refreshed.

    The boto3 session and its credentials expiration are shared by all
    descriptors through class variables, while the built client is cached
    on each descriptor object.
    """
    session: Optional[boto3.Session] = None  # class var
    exp: Optional[datetime] = None  # class var, session's creds expiration
    # class var, incremented each time the session credentials are
    # replaced; lets every descriptor notice a refresh done by another one
    _creds_version: int = 0

    def __init__(self, service_name: str, region_name: Optional[str] = None):
        """
        :param service_name: boto3 service name the client is created for
        :param region_name: optional region; when not set, the region is
            taken from the environment service on access
        """
        # TODO add role_arn to input params
        self._service_name = service_name
        self._region_name = region_name
        self._client: Optional[BaseClient] = None
        # value of _creds_version the cached client was built with
        self._client_creds_version: int = 0

    @classmethod
    def get_session(cls) -> boto3.Session:
        """
        Return the shared boto3 session, creating it on the first call.
        """
        if not cls.session:
            _LOG.info('Initializing boto3 session inside '
                      'ModularAssumeRoleClient descriptor')
            cls.session = boto3.Session()
        return cls.session

    @classmethod
    def _expired(cls) -> bool:
        """
        Tell whether the session credentials need to be (re-)assumed:
        True when no expiration was recorded yet or when the recorded one
        is less than five minutes ahead.
        """
        in_a_while = utc_datetime() + timedelta(minutes=5)
        return not isinstance(cls.exp, datetime) or in_a_while > cls.exp

    @classmethod
    def _update_session(cls, aws_access_key_id: str,
                        aws_secret_access_key: str,
                        aws_session_token: str, expiration: datetime):
        """
        Put new credentials into the shared session, remember when they
        expire and bump the credentials version.
        """
        # boto3.Session has no public way to replace credentials, so the
        # underlying botocore session is used.
        cls.get_session()._session.set_credentials(
            access_key=aws_access_key_id,
            secret_key=aws_secret_access_key,
            token=aws_session_token
        )
        cls.exp = expiration
        cls._creds_version += 1

    def __get__(self, instance, owner) -> BaseClient:
        """
        Return a client for the configured service, re-assuming the
        modular role first when roles are configured and the shared
        credentials are missing or about to expire.

        We cannot use sts.assure_modular_credentials_valid() and
        BaseRoleAccessModel's logic here as well (I mean creds from envs)
        because here we cannot catch the moment when creds were refreshed
        by models, for instance. Think about it
        """
        _modular = Modular()
        sts, env = _modular.sts_service(), _modular.environment_service()
        roles = env.modular_assume_role_arn()
        if roles and self._expired():
            _LOG.info('Boto3 session inside ModularAssumeRoleClient descriptor '
                      'has expired. Re-assuming role')
            payloads = list(sts.assume_roles_default_payloads(roles))
            # presumably a mapping whose keys match the parameters of
            # _update_session -- verify against sts.assume_roles_chain
            creds = sts.assume_roles_chain(payloads)
            self._update_session(**creds)

        # A client keeps the credentials it was created with, so it must
        # be rebuilt after every refresh of the shared session. Comparing
        # versions (instead of resetting only this descriptor's client)
        # also covers a refresh triggered by a different descriptor.
        version = self._creds_version
        if self._client is None or self._client_creds_version != version:
            r = (self._region_name or env.modular_aws_region()
                 or env.aws_region())
            _LOG.info(f'Initializing {self._service_name} client within '
                      f'ModularAssumeRoleClient descriptor for region {r}')
            self._client = self.get_session().client(
                service_name=self._service_name,
                region_name=r,
            )
            self._client_creds_version = version
        return self._client