src/services/clients/cognito.py (338 lines of code) (raw):
from abc import ABC, abstractmethod
from datetime import datetime
from functools import cached_property
from http import HTTPStatus
from typing import Generator, Iterator, TYPE_CHECKING
from typing_extensions import NotRequired, Self, TypedDict
from botocore.exceptions import ClientError
from helpers.constants import (
CUSTOM_CUSTOMER_ATTR,
CUSTOM_LATEST_LOGIN_ATTR,
CUSTOM_ROLE_ATTR,
)
from helpers.lambda_response import ResponseFactory
from helpers.log_helper import get_logger
from helpers.time_helper import utc_datetime, utc_iso
from services.environment_service import EnvironmentService
from services.clients import Boto3ClientFactory
if TYPE_CHECKING:
from models.user import User
from botocore.client import BaseClient
_LOG = get_logger(__name__)
class _CognitoUserAttr(TypedDict):
Name: str
Value: str
class CognitoUserModel(TypedDict):
Username: str
UserAttributes: NotRequired[list[_CognitoUserAttr]] # either this one
Attributes: NotRequired[list[_CognitoUserAttr]] # or this one
UserCreateDate: datetime
UserLastModifiedDate: datetime
Enabled: bool
class UserWrapper:
__slots__ = ('id', 'username', 'customer', 'role', 'latest_login',
'created_at')
def __init__(self, username: str, customer: str | None = None,
role: str | None = None, latest_login: datetime | None = None,
created_at: datetime | None = None, sub: str | None = None):
"""
Sub is not used currently, so it's not important. Username represents
user id
:param username:
:param customer:
:param role:
:param latest_login:
:param created_at:
:param sub:
"""
self.username = username
self.customer = customer
self.role = role
self.latest_login = latest_login
self.created_at = created_at
self.id = sub
@classmethod
def from_user_model(cls, user: 'User') -> Self:
ll = None
if user.latest_login:
ll = utc_datetime(user.latest_login)
ca = None
if user.created_at:
ca = utc_datetime(user.created_at)
return cls(
sub=str(user.mongo_id), # noqa valid for onprem
username=user.user_id,
customer=user.customer,
role=user.role,
latest_login=ll,
created_at=ca
)
@classmethod
def from_cognito_model(cls, model: CognitoUserModel) -> Self:
attrs = model.get('UserAttributes') or model.get('Attributes') or ()
attributes = {a['Name']: a['Value'] for a in attrs}
ll = None
if item := attributes.get(CUSTOM_LATEST_LOGIN_ATTR):
ll = utc_datetime(item)
return cls(
sub=attributes.get('sub'), # valid for onprem
username=model['Username'],
customer=attributes.get(CUSTOM_CUSTOMER_ATTR),
role=attributes.get(CUSTOM_ROLE_ATTR),
latest_login=ll,
created_at=model['UserCreateDate']
)
def get_dto(self) -> dict:
return {
'username': self.username,
'customer': self.customer,
'role': self.role,
'latest_login': utc_iso(
self.latest_login) if self.latest_login else None,
'created_at': utc_iso(self.created_at) if self.created_at else None
}
class UsersIterator(Iterator[UserWrapper]):
next_token: str | int | None = None
def __iter__(self):
return self
def __next__(self) -> UserWrapper:
raise NotImplementedError
class CognitoUsersIterator(UsersIterator):
__slots__ = '_cl', '_upi', '_customer', '_limit', 'next_token'
def __init__(self, client: 'BaseClient', user_pool_id: str,
customer: str | None = None,
limit: int | None = None, next_token: str | None = None):
self._cl = client
self._upi = user_pool_id
self._customer = customer
self._limit = limit
self.next_token = next_token
def _get_next_page(self, limit: int | None = None,
token: str | None = None
) -> tuple[list[CognitoUserModel], str | None]:
params = dict(UserPoolId=self._upi)
if limit:
params['Limit'] = limit
if token:
params['PaginationToken'] = token
try:
res = self._cl.list_users(**params)
return res.get('Users') or [], res.get('PaginationToken')
except ClientError:
_LOG.warning('Unexpected error occurred listing users',
exc_info=True)
return [], None
def __iter__(self) -> Generator[UserWrapper, None, None]:
# local vars
_limit = self._limit
first = True
customer = self._customer
while _limit != 0 and (first or self.next_token):
res = self._get_next_page(_limit, self.next_token)
first = False
self.next_token = res[1]
for user in map(UserWrapper.from_cognito_model, res[0]):
if customer and user.customer != customer:
continue
yield user
if _limit is not None:
_limit -= 1
class AuthenticationResult(TypedDict):
id_token: str
refresh_token: str | None
expires_in: int
class BaseAuthClient(ABC):
@abstractmethod
def get_user_by_username(self, username: str) -> UserWrapper | None:
pass
@abstractmethod
def query_users(self, customer: str | None = None,
limit: int | None = None,
next_token: str | dict | None = None) -> UsersIterator:
pass
@abstractmethod
def set_user_password(self, username: str, password: str) -> bool:
pass
@abstractmethod
def update_user_attributes(self, user: UserWrapper):
"""
Updates all the attributes that are not equal to None in user wrapper
:param user:
:return:
"""
@abstractmethod
def delete_user(self, username: str) -> None:
pass
@abstractmethod
def authenticate_user(self, username: str, password: str
) -> AuthenticationResult | None:
pass
@abstractmethod
def refresh_token(self, refresh_token: str) -> AuthenticationResult | None:
pass
@abstractmethod
def signup_user(self, username: str, password: str,
customer: str | None = None, role: str | None = None) -> UserWrapper:
pass
def does_user_exist(self, username: str) -> bool:
"""
Use only if you don't need the user's data
:param username:
:return:
"""
return not not self.get_user_by_username(username)
class CognitoClient(BaseAuthClient):
def __init__(self, environment_service: EnvironmentService):
self._env = environment_service
@cached_property
def client(self):
return Boto3ClientFactory('cognito-idp').build(region_name=self._env.aws_region())
@property
def user_pool_name(self) -> str:
return self._env.get_user_pool_name()
@cached_property
def user_pool_id(self) -> str:
_LOG.info('Retrieving user pool id')
_id = self._env.get_user_pool_id()
if not _id:
_LOG.warning('User pool id is not found in envs. '
'Scanning all the available pools to get the id')
_id = self._pool_id_from_name(self.user_pool_name)
if not _id:
_message = 'Application Authentication Service is ' \
'not configured properly.'
_LOG.error(f'User pool \'{self.user_pool_name}\' does '
f'not exists. {_message}')
raise ResponseFactory(HTTPStatus.SERVICE_UNAVAILABLE).message(
_message).exc()
return _id
@property
def client_id(self) -> str:
client = self.client.list_user_pool_clients(
UserPoolId=self.user_pool_id, MaxResults=1)['UserPoolClients']
if not client:
_message = 'Application Authentication Service is not ' \
'configured properly: no client applications found'
_LOG.error(_message)
raise ResponseFactory(HTTPStatus.SERVICE_UNAVAILABLE).message(
_message
).exc()
return client[0]['ClientId']
def _pool_id_from_name(self, name: str) -> str | None:
"""
Since AWS Cognito can have two different pools with equal names,
this method returns the first pool id which will be found.
"""
for pool in self._list_user_pools():
if pool['Name'] == name:
return pool['Id']
def _list_user_pools(self) -> Generator[dict, None, None]:
first = True
params = dict(MaxResults=10)
while params.get('NextToken') or first:
pools = self.client.list_user_pools(**params)
yield from pools.get('UserPools') or []
params['NextToken'] = pools.get('NextToken')
if first:
first = False
def get_user_by_username(self, username: str) -> UserWrapper | None:
try:
item = self.client.admin_get_user(
UserPoolId=self.user_pool_id,
Username=username
)
_LOG.debug(f'Result of admin_get_user: {item}')
return UserWrapper.from_cognito_model(item)
except ClientError as e:
_LOG.warning(f'ClientError occurred querying a user, {e}')
return
def query_users(self, customer: str | None = None,
limit: int | None = None,
next_token: str | None = None) -> UsersIterator:
return CognitoUsersIterator(
client=self.client,
user_pool_id=self.user_pool_id,
limit=limit,
next_token=next_token,
customer=customer
)
def set_user_password(self, username: str, password: str) -> bool:
try:
self.client.admin_set_user_password(
UserPoolId=self.user_pool_id,
Username=username,
Password=password,
Permanent=True
)
return True
except ClientError:
_LOG.warning('Could not set user password due to client error',
exc_info=True)
return False
def update_user_attributes(self, user: UserWrapper):
def attr(n, v):
return dict(Name=n, Value=v)
attributes = []
if user.customer:
attributes.append(attr(CUSTOM_CUSTOMER_ATTR, user.customer))
if user.role:
attributes.append(attr(CUSTOM_ROLE_ATTR, user.customer))
if user.latest_login:
attributes.append(attr(CUSTOM_LATEST_LOGIN_ATTR,
utc_iso(user.latest_login)))
if attributes:
self.client.admin_update_user_attributes(
UserPoolId=self.user_pool_id,
Username=user.username,
UserAttributes=attributes
)
def delete_user(self, username: str) -> None:
try:
self.client.admin_delete_user(
UserPoolId=self.user_pool_id,
Username=username
)
except ClientError as e:
if e.response['Error']['Code'] == 'UserNotFoundException':
pass
raise e
def authenticate_user(self, username: str, password: str
) -> AuthenticationResult | None:
try:
r = self.client.admin_initiate_auth(
UserPoolId=self.user_pool_id,
ClientId=self.client_id,
AuthFlow='ADMIN_NO_SRP_AUTH',
AuthParameters={'USERNAME': username, 'PASSWORD': password}
)
return {
'id_token': r['AuthenticationResult']['IdToken'],
'refresh_token': r['AuthenticationResult'].get('RefreshToken'),
'expires_in': r['AuthenticationResult']['ExpiresIn']
}
except self.client.exceptions.UserNotFoundException:
return
except self.client.exceptions.NotAuthorizedException:
return
def refresh_token(self, refresh_token: str) -> AuthenticationResult | None:
try:
r = self.client.admin_initiate_auth(
UserPoolId=self.user_pool_id,
ClientId=self.client_id,
AuthFlow='REFRESH_TOKEN_AUTH',
AuthParameters={'REFRESH_TOKEN': refresh_token}
)
return {
'id_token': r['AuthenticationResult']['IdToken'],
'refresh_token': r['AuthenticationResult'].get('RefreshToken'),
'expires_in': r['AuthenticationResult']['ExpiresIn']
}
except ClientError:
_LOG.warning('Client error occurred trying to refresh token',
exc_info=True)
def signup_user(self, username: str, password: str,
customer: str | None = None, role: str | None = None
) -> UserWrapper:
def attr(n, v):
return dict(Name=n, Value=v)
attrs = [attr('name', username)]
if customer:
attrs.append(attr(CUSTOM_CUSTOMER_ATTR, customer))
if role:
attrs.append(attr(CUSTOM_ROLE_ATTR, role))
validation_data = [attr('name', username)]
res = self.client.sign_up(
ClientId=self.client_id,
Username=username,
Password=password,
UserAttributes=attrs,
ValidationData=validation_data
)
self.client.admin_set_user_password(
UserPoolId=self.user_pool_id,
Username=username,
Password=password,
Permanent=True
)
return UserWrapper(
username=username,
customer=customer,
role=role,
created_at=utc_datetime(),
sub=res['UserSub'],
)