cli/srecli/service/helpers.py (203 lines of code) (raw):

import base64 import json import re import secrets import string import sys import time import urllib.error import urllib.request import uuid from datetime import datetime, timezone from typing import Callable, TypeVar from dateutil.parser import isoparse from urllib3.exceptions import LocationParseError from urllib3.util import parse_url def urljoin(*args: str) -> str: """ This method somehow differs from urllib.parse.urljoin. See: >>> urljoin('one', 'two', 'three') 'one/two/three' >>> urljoin('one/', '/two/', '/three/') 'one/two/three' >>> urljoin('https://example.com/', '/prefix', 'path/to/service') 'https://example.com/prefix/path/to/service' :param args: list of string :return: """ return '/'.join(map(lambda x: str(x).strip('/'), args)) def sifted(data: dict) -> dict: """ >>> sifted({'k': 'value', 'k1': None, 'k2': '', 'k3': 0, 'k4': False}) {'k': 'value', 'k3': 0, 'k4': False} :param data: :return: """ return {k: v for k, v in data.items() if isinstance(v, (bool, int)) or v} class JWTToken: """ A simple wrapper over jwt token """ EXP_THRESHOLD = 300 # in seconds __slots__ = '_token', '_exp_threshold' def __init__(self, token: str, exp_threshold: int = EXP_THRESHOLD): self._token = token self._exp_threshold = exp_threshold @property def raw(self) -> str: return self._token @property def payload(self) -> dict | None: try: return json.loads( base64.b64decode(self._token.split('.')[1] + '==').decode() ) except Exception: return def is_expired(self) -> bool: p = self.payload if not p: return True exp = p.get('exp') if not exp: return False return exp < time.time() + self._exp_threshold def gen_password(digits: int = 20) -> str: allowed_punctuation = ''.join(set(string.punctuation) - {'"', "'", "!"}) chars = string.ascii_letters + string.digits + allowed_punctuation while True: password = ''.join(secrets.choice(chars) for _ in range(digits)) + '=' if (any(c.islower() for c in password) and any(c.isupper() for c in password) and sum(c.isdigit() for c in password) >= 3): break return password def utc_datetime(_from: str | None = None) -> datetime: """ Returns time-zone aware datetime object in UTC. You can optionally pass an existing ISO string. The function will parse it to object and make it UTC if it's not :params _from: Optional[str] :returns: datetime """ obj = datetime.now(timezone.utc) if not _from else isoparse(_from) return obj.astimezone(timezone.utc) def utc_iso(_from: datetime | None = None) -> str: """ Returns time-zone aware datetime ISO string in UTC with military suffix. You can optionally pass datetime object. The function will make it UTC if it's not and serialize to string :param _from: Optional[datetime] :returns: str """ obj = _from or utc_datetime() return obj.astimezone(timezone.utc).isoformat().replace('+00:00', 'Z') def build_cloudtrail_record(cloud_identifier: str, region: str, event_source: str, event_name: str) -> dict: return { "eventTime": utc_iso(), "awsRegion": region, "userIdentity": { "accountId": cloud_identifier }, "eventSource": event_source, "eventName": event_name } def normalize_lists(lists: list[list[str]]): """ Changes the given lists in place making them all equal length by repeating the last attr the necessary number of times. :param lists: :return: """ lens = [len(l) for l in lists] if not all(lens): raise ValueError('Each list must have at least one value') max_len = max(lens) for l in lists: l_len = len(l) if l_len < max_len: l += [l[-1] for _ in range(max_len - l_len)] assert len(set(len(l) for l in lists)) == 1 # equal lens def build_cloudtrail_records(cloud_identifier: list, region: list, event_source: list, event_name: list) -> list: """ Builds CloudTrail log records based on given params. If you still don't get it just execute the function with some random parameters (no validation of parameters content provided) and see the result. """ records = [] lists = [cloud_identifier, region, event_source, event_name] normalize_lists(lists) for i in range(len(lists[0])): records.append( build_cloudtrail_record(cloud_identifier[i], region[i], event_source[i], event_name[i])) return records def build_eventbridge_record(detail_type: str, source: str, account: str, region: str, detail: dict) -> dict: return { "version": "0", "id": str(uuid.uuid4()), "detail-type": detail_type, "source": source, "account": account, "time": utc_iso(), "region": region, "resources": [], "detail": detail } def build_maestro_record(event_action: str, group: str, sub_group: str, tenant_name: str, cloud: str): """ Only necessary attributes are kept :param event_action: :param group: :param sub_group: :param tenant_name: :param cloud: :return: """ return { "_id": str(uuid.uuid4()), "eventAction": event_action, "group": group, "subGroup": sub_group, "tenantName": tenant_name, "eventMetadata": { "request": {'cloud': cloud}, # todo native maestro events have string here "cloud": cloud } } def validate_api_link(url: str) -> str | None: url = url.lstrip() if "://" in url and not url.lower().startswith("http"): return 'Invalid API link: not supported scheme' try: scheme, auth, host, port, path, query, fragment = parse_url(url) except LocationParseError as e: return 'Invalid API link' if not scheme: return 'Invalid API link: missing scheme' if not host: return 'Invalid API link: missing host' try: req = urllib.request.Request(url) urllib.request.urlopen(req) except urllib.error.HTTPError as e: pass except urllib.error.URLError as e: return 'Invalid API link: cannot make a request' RT = TypeVar('RT') # return type ET = TypeVar('ET', bound=Exception) # exception type def catch(func: Callable[[], RT], exception: type[ET] = Exception ) -> tuple[RT | None, ET | None]: """ Calls the provided function and catches the desired exception. Seems useful to me :) ? :param func: :param exception: :return: """ try: return func(), None except exception as e: return None, e class Version(tuple): """ Limited version. Additional labels, pre-release labels and build metadata are not supported. Tuple with three elements (integers): (Major, Minor, Patch). Minor and Patch can be missing. It that case they are 0. This class is supposed to be used primarily by rulesets versioning """ _not_allowed = re.compile(r'[^.0-9]') def __new__(cls, seq: str | tuple[int, int, int] = (0, 0, 0) ) -> 'Version': if isinstance(seq, Version): return seq if isinstance(seq, str): seq = cls._parse(seq) return tuple.__new__(Version, seq) @classmethod def _parse(cls, version: str) -> tuple[int, int, int]: """ Raises ValueError """ prepared = re.sub(cls._not_allowed, '', version).strip('.') items = tuple(map(int, prepared.split('.'))) match len(items): case 3: return items case 2: return items[0], items[1], 0 case 1: return items[0], 0, 0 case _: raise ValueError( 'Cannot parse. Version must have one of formats: 1, 2.3, 4.5.6' ) @property def major(self) -> int: return self[0] @property def minor(self) -> int: return self[1] @property def patch(self) -> int | None: return self[2] def to_str(self) -> str: return '.'.join(map(str, self)) def __str__(self) -> str: return self.to_str() def check_version_compatibility(api: str, cli: str, /) -> None: if not api: print('Custodian API did not return the version number!', file=sys.stderr) return cli_version = Version(cli) api_version = Version(api) if cli_version > api_version: print(f'Consider that you SRE CLI version {cli_version} is ' f'higher than the API version {api_version}', file=sys.stderr) return if cli_version.major < api_version.major: print(f'CLI Major version {cli_version} is lower than ' f'the API version {api_version}. Please, update the CLI', file=sys.stderr) sys.exit(1) if cli_version.minor < api_version.minor: print(f'CLI Minor version {cli_version} is lower than the ' f'API version {api_version}. Some features may not ' f'work. Consider updating the SRE CLI', file=sys.stderr)