src/services/ruleset_service.py (335 lines of code) (raw):

import hashlib from typing import BinaryIO, Generator, Iterable, Iterator, Optional import msgspec from helpers import Version from helpers.constants import ( COMPOUND_KEYS_SEPARATOR, ED_AWS_RULESET_NAME, ED_AZURE_RULESET_NAME, ED_GOOGLE_RULESET_NAME, ED_KUBERNETES_RULESET_NAME, ID_ATTR, LICENSED_ATTR, NAME_ATTR, RULES_ATTR, RULES_NUMBER, VERSION_ATTR, ) from helpers.system_customer import SYSTEM_CUSTOMER from helpers.time_helper import utc_iso from models.ruleset import RULESET_LICENSES, RULESET_STANDARD, Ruleset from services.base_data_service import BaseDataService from services.clients.s3 import S3Client class RulesetService(BaseDataService[Ruleset]): def __init__(self, s3_client: S3Client): super().__init__() self._s3_client = s3_client def iter_licensed( self, name: Optional[str] = None, version: Optional[str] = None, cloud: Optional[str] = None, ascending: bool = False, limit: Optional[int] = None, ) -> Iterator[Ruleset]: if version and not name: raise AssertionError('Invalid usage') filter_condition = None if cloud: filter_condition &= Ruleset.cloud == cloud.upper() sort_key = ( f'{SYSTEM_CUSTOMER}{COMPOUND_KEYS_SEPARATOR}' f'{RULESET_LICENSES}{COMPOUND_KEYS_SEPARATOR}' ) if name: sort_key += f'{name}{COMPOUND_KEYS_SEPARATOR}' if version: sort_key += f'{version}' return self.model_class.customer_id_index.query( hash_key=SYSTEM_CUSTOMER, range_key_condition=(self.model_class.id.startswith(sort_key)), scan_index_forward=ascending, limit=limit, ) def iter_standard( self, customer: str, name: Optional[str] = None, version: Optional[str] = None, cloud: Optional[str] = None, event_driven: Optional[bool] = False, ascending: Optional[bool] = False, limit: Optional[int] = None, **kwargs, ) -> Iterator[Ruleset]: if version and not name: raise AssertionError('Invalid usage') filter_condition = None if cloud: filter_condition &= Ruleset.cloud == cloud.upper() if isinstance(event_driven, bool): filter_condition &= Ruleset.event_driven == event_driven sort_key = ( f'{customer}{COMPOUND_KEYS_SEPARATOR}' f'{RULESET_STANDARD}{COMPOUND_KEYS_SEPARATOR}' ) if name: sort_key += f'{name}{COMPOUND_KEYS_SEPARATOR}' if version: sort_key += f'{version}' return self.model_class.customer_id_index.query( hash_key=customer, range_key_condition=(self.model_class.id.startswith(sort_key)), scan_index_forward=ascending, limit=limit, filter_condition=filter_condition, ) def by_id(self, id: str, attributes_to_get: list = None) -> Ruleset | None: return self.get_nullable( hash_key=id, attributes_to_get=attributes_to_get ) def iter_by_id(self, ids: Iterable[str]) -> Generator[Ruleset, None, None]: processed = set() for _id in ids: if _id in processed: continue ruleset = self.by_id(_id) if ruleset: yield ruleset processed.add(_id) def by_lm_id( self, lm_id: str, attributes_to_get: Optional[list] = None ) -> Optional[Ruleset]: return next( self.model_class.license_manager_id_index.query( hash_key=lm_id, limit=1, attributes_to_get=attributes_to_get ), None, ) def iter_by_lm_id( self, lm_ids: Iterable[str] ) -> Generator[Ruleset, None, None]: processed = set() for _id in lm_ids: if _id in processed: continue ruleset = self.by_lm_id(_id) if ruleset: yield ruleset processed.add(_id) def get_standard( self, customer: str, name: str, version: str ) -> Ruleset | None: return self.by_id( id=self.build_id( customer=customer, licensed=False, name=name, version=version ) ) def get_latest(self, customer: str, name: str) -> Ruleset | None: return next( self.iter_standard( customer=customer, name=name, ascending=False, limit=1 ), None, ) def create( self, customer: str, name: str, version: str, cloud: str, rules: list, event_driven: bool = False, s3_path: dict | None = None, status: dict | None = None, licensed: bool = False, license_keys: list | None = None, license_manager_id: str | None = None, versions: list[str] | None = None, created_at: str | None = None, description: str | None = None, **kwargs, ) -> Ruleset: s3_path = s3_path or {} status = status or {} license_keys = license_keys or [] return Ruleset( id=self.build_id(customer, licensed, name, version), customer=customer, cloud=cloud, event_driven=event_driven, rules=rules, s3_path=s3_path or {}, status=status or {}, license_keys=license_keys or [], license_manager_id=license_manager_id, created_at=created_at or utc_iso(), versions=versions or [], description=description, ) def create_event_driven( self, cloud: str, version: str, rules: list[str] ) -> Ruleset: return self.create( customer=SYSTEM_CUSTOMER, name=self.ed_ruleset_name(cloud), version=version, cloud=cloud, rules=rules, event_driven=True, licensed=False, description='System event driven ruleset', ) def get_previous_ruleset( self, ruleset: Ruleset, limit: Optional[int] = None ) -> Iterator[Ruleset]: """ Returns previous versions of the same ruleset :param ruleset: :param limit: :return: """ return self.model_class.customer_id_index.query( hash_key=ruleset.customer, range_key_condition=(self.model_class.id < ruleset.id), scan_index_forward=False, limit=limit, ) def build_id( self, customer: str, licensed: bool, name: str, version: str ) -> str: return COMPOUND_KEYS_SEPARATOR.join( map(str, (customer, self.licensed_tag(licensed), name, version)) ) @staticmethod def licensed_tag(licensed: bool) -> str: return RULESET_LICENSES if licensed else RULESET_STANDARD @staticmethod def build_s3_key(ruleset: Ruleset) -> str: return S3Client.safe_key( f'{ruleset.customer}/{ruleset.name}/{ruleset.version}' ) def delete(self, item: Ruleset): super().delete(item) s3_path = item.s3_path.as_dict() if s3_path: self._s3_client.gz_delete_object( bucket=s3_path.get('bucket_name'), key=s3_path.get('path') ) def dto(self, ruleset: Ruleset, params_to_exclude=None) -> dict: ruleset_json = ruleset.get_json() ruleset_json[RULES_NUMBER] = len(ruleset_json.get(RULES_ATTR) or []) ruleset_json[NAME_ATTR] = ruleset.name if v := ruleset.version: ruleset_json[VERSION_ATTR] = v ruleset_json[LICENSED_ATTR] = ruleset.licensed ruleset_json.pop(ID_ATTR, None) ruleset_json.pop('status', None) ruleset_json.pop('allowed_for', None) ruleset_json.pop('active', None) ruleset_json.pop('license_manager_id', None) ruleset_json.pop('s3_path', None) for param in params_to_exclude or (): ruleset_json.pop(param, None) return ruleset_json @staticmethod def set_s3_path(ruleset: Ruleset, bucket: str, key: str): ruleset.s3_path = {'bucket_name': bucket, 'path': key} def iter_event_driven( self, cloud: str, ascending: bool = False, limit: int | None = None ) -> Iterator[Ruleset]: """ Iterates over event-driven rulesets for cloud :param cloud: :param ascending: :param limit: :return: """ sk = self.build_id( customer=SYSTEM_CUSTOMER, licensed=False, name=self.ed_ruleset_name(cloud), version='', ) return Ruleset.customer_id_index.query( hash_key=SYSTEM_CUSTOMER, range_key_condition=Ruleset.id.startswith(sk), filter_condition=(Ruleset.event_driven == True), scan_index_forward=ascending, limit=limit, ) def get_latest_event_driven(self, cloud: str) -> Ruleset | None: return next(self.iter_event_driven(cloud, limit=1), None) def get_event_driven(self, cloud: str, version: str) -> Ruleset | None: return self.get_standard( customer=SYSTEM_CUSTOMER, name=self.ed_ruleset_name(cloud), version=version, ) def download(self, ruleset: Ruleset, out: BinaryIO = None) -> BinaryIO: return self._s3_client.gz_get_object( bucket=ruleset.s3_path['bucket_name'], key=ruleset.s3_path['path'], buffer=out, ) def download_url(self, ruleset: Ruleset) -> str: """ Returns a presigned url to the given file :param ruleset: :return: """ return self._s3_client.prepare_presigned_url( self._s3_client.gz_download_url( bucket=ruleset.s3_path['bucket_name'], key=ruleset.s3_path['path'], filename=ruleset.name, # not so important response_encoding='gzip', ) ) @staticmethod def payload_hash(payload: list | dict) -> str: if isinstance(payload, dict): policies = payload.get('policies') or [] else: policies = payload name_to_body = {p['name']: p for p in policies} return RulesetService.hash_from_name_to_body(name_to_body) @staticmethod def hash_from_name_to_body(name_to_body: dict) -> str: """ Calculates hash of ruleset's policies. Does not consider other from policies data :param name_to_body: dict where keys are names and value are bodies :return: """ data = msgspec.json.encode(name_to_body, order='deterministic') return hashlib.sha256(data).hexdigest() @staticmethod def ed_ruleset_name(cloud: str) -> str: match cloud: case 'AWS': return ED_AWS_RULESET_NAME case 'AZURE': return ED_AZURE_RULESET_NAME case 'GOOGLE' | 'GCP': return ED_GOOGLE_RULESET_NAME case 'KUBERNETES' | 'K8S': return ED_KUBERNETES_RULESET_NAME class RulesetName(tuple): @staticmethod def _parse_name(n: str) -> tuple[str, Version | None, str | None]: """ Name can be: - FULL_AWS - FULL_AWS:1.4.0 - 5131c559-ac8d-4842-b1a0-92c766b7ec8c:FULL_AWS - 5131c559-ac8d-4842-b1a0-92c766b7ec8c:FULL_AWS:1.7.0 :param n: :return: (name, version, license_key) :raises: ValueError """ items = n.strip().strip(':').split(':', maxsplit=2) match len(items): case 3: # all three are given return items[1], Version(items[2]), items[0] case 2: # name and version or license_key and name first, second = items try: return first, Version(second), None except ValueError: return second, None, first case _: # only name return items[0], None, None def __new__( cls, n: str, v: str | None = None, lk: str | None = None ) -> 'RulesetName': if isinstance(n, RulesetName): return n name, version, license_key = cls._parse_name(n) if v: version = Version(v) if lk: license_key = lk return tuple.__new__(RulesetName, (name, version, license_key)) @property def name(self) -> str: return self[0] @property def version(self) -> Version | None: return self[1] @property def license_key(self) -> str | None: return self[2] def to_str(self, include_license: bool = True) -> str: name = self.name if v := self.version: name = f'{name}:{v.to_str()}' if (lk := self.license_key) and include_license: name = f'{lk}:{name}' return name