modular_sdk/models/pynamodb_extension/pynamodb_to_pymongo_adapter.py
import json
import re
import decimal
from itertools import islice
from typing import Optional, Dict, List, Union, TypeVar, Iterator
from pymongo import DeleteOne, ReplaceOne, DESCENDING, ASCENDING
from pymongo.collection import Collection, ReturnDocument
from pymongo.errors import BulkWriteError
from pynamodb import indexes
from pynamodb.expressions.condition import Condition
from pynamodb.expressions.operand import Value, Path, _ListAppend
from pynamodb.expressions.update import SetAction, RemoveAction, Action
from pynamodb.models import Model
from pynamodb.settings import OperationSettings
from modular_sdk.commons import DynamoDBJsonSerializer
from modular_sdk.connections.mongodb_connection import MongoDBConnection
T = TypeVar('T')
class Result(Iterator[T]):
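    """
    Wraps an item iterator and mimics DynamoDB-style pagination: the number
    of documents consumed so far (the initial offset plus iterated items) is
    exposed as last_evaluated_key until it reaches the total number of
    matched documents (page_size), after which it becomes None.
    """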
def __init__(self, result: Iterator[T],
_evaluated_key: Optional[int] = None,
page_size: Optional[int] = None):
self._result_it = result
self._evaluated_key = _evaluated_key
self._page_size = page_size
@property
def last_evaluated_key(self):
_key = self._evaluated_key
if _key is not None and _key < self._page_size:
return _key
def __iter__(self):
return self
def __next__(self) -> T:
item = self._result_it.__next__()
if self._evaluated_key is not None:
self._evaluated_key += 1
return item
class BatchWrite:
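    """
    Collects ReplaceOne/DeleteOne requests for the model's collection and
    flushes them with a single bulk_write() on commit() (also called when
    the context manager exits).
    """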
def __init__(self, model, mongo_connection):
self.collection_name = model.Meta.table_name
self.mongo_connection = mongo_connection
self.request = []
def save(self, put_item):
json_to_save = put_item.dynamodb_model()
json_to_save.pop('mongo_id', None)
encoded_document = self.mongo_connection.encode_keys(
{
key: value for key, value in json_to_save.items()
if value is not None
}
)
self.request.append(ReplaceOne(put_item.get_keys(),
encoded_document, upsert=True))
def delete(self, del_item):
self.request.append(DeleteOne(del_item._get_keys()))
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return self.commit()
def commit(self):
collection = self.mongo_connection.collection(
collection_name=self.collection_name)
if not self.request:
return
try:
collection.bulk_write(self.request)
except BulkWriteError:
pass
class _PynamoDBExpressionsConverter:
# Looks for [1], [2], [12], etc in a string
    index_regex: re.Pattern = re.compile(r'\[\d+\]')
@staticmethod
def _preprocess(val: T) -> T:
"""
Convert some values that are not accepted by mongodb:
- decimal.Decimal
Changes the given collection in place but also returns it
"""
if isinstance(val, dict):
for k, v in val.items():
val[k] = _PynamoDBExpressionsConverter._preprocess(v)
return val
if isinstance(val, list):
for i, v in enumerate(val):
val[i] = _PynamoDBExpressionsConverter._preprocess(v)
return val
if isinstance(val, decimal.Decimal):
return float(val)
return val
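    # e.g. _preprocess({'n': decimal.Decimal('1.5'), 'l': [decimal.Decimal('2')]})
    # returns {'n': 1.5, 'l': [2.0]} (the same collection, mutated in place)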
@staticmethod
def value_to_raw(value: Value) -> Union[str, dict, list, int, float]:
"""
        A PynamoDB operand Value holds a single element in its values list.
        That element is a DynamoDB-typed dict: {'<dynamo type>': <value>}
:param value:
:return:
"""
val = DynamoDBJsonSerializer.deserializer.deserialize(value.value)
        # We could return val as is, but some of its values (top-level and
        # nested) may contain decimal.Decimal, which MongoDB does not accept,
        # so they are converted to plain floats.
return _PynamoDBExpressionsConverter._preprocess(val)
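    # Illustrative example: for an operand whose .value is {'S': 'foo'},
    # value_to_raw returns 'foo'; {'N': '1.5'} deserializes to
    # decimal.Decimal('1.5'), which _preprocess turns into the float 1.5.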
@classmethod
def path_to_raw(cls, path: Path) -> str:
"""
        You can query MongoDB by nested attributes (one.two.three) and by
        the first level of nested lists (one.two.3), but not deeper, i.e.
        'one.two.3.four' won't work.
        PynamoDB renders its Path a bit differently: one.two[3]. We just
        need to rewrite it as one.two.3
:param path:
:return:
"""
raw = str(path)
for index in re.findall(cls.index_regex, raw):
n = index.strip('[]')
raw = raw.replace(index, f'.{n}')
return raw
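    # e.g. a PynamoDB path rendered as 'tags[0].labels[2]' becomes
    # 'tags.0.labels.2'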
class ConditionConverter(_PynamoDBExpressionsConverter):
"""
Converts PynamoDB conditions to MongoDB query map. Supported classes from
`pynamodb.expressions.condition`: Comparison, Between, In, Exists,
NotExists, BeginsWith, Contains, And, Or, Not.
    IsType and size are not supported; add support if you need them.
"""
comparison_map: Dict[str, str] = {
'>': '$gt',
'<': '$lt',
'>=': '$gte',
'<=': '$lte',
'<>': '$ne'
}
@classmethod
def convert(cls, condition: Condition) -> dict:
op = condition.operator
if op == 'OR':
return {
'$or': [cls.convert(cond) for cond in condition.values]
}
if op == 'AND':
return {
'$and': [cls.convert(cond) for cond in condition.values]
}
if op == 'NOT':
return {
'$nor': [cls.convert(condition.values[0])]
}
if op == 'attribute_exists':
return {
cls.path_to_raw(condition.values[0]): {'$exists': True}
}
if op == 'attribute_not_exists':
return {
cls.path_to_raw(condition.values[0]): {'$exists': False}
}
if op == 'contains':
return {
cls.path_to_raw(condition.values[0]): {
'$regex': cls.value_to_raw(condition.values[1])
}
}
if op == 'IN':
return {
cls.path_to_raw(condition.values[0]): {
'$in': list(
cls.value_to_raw(v) for v in
islice(condition.values, 1, None)
)
}
}
if op == '=':
return {
cls.path_to_raw(condition.values[0]): cls.value_to_raw(
condition.values[1])
}
if op in cls.comparison_map:
_mongo_op = cls.comparison_map[op]
return {
cls.path_to_raw(condition.values[0]): {
_mongo_op: cls.value_to_raw(condition.values[1])
}
}
if op == 'BETWEEN':
return {
cls.path_to_raw(condition.values[0]): {
'$gte': cls.value_to_raw(condition.values[1]),
'$lte': cls.value_to_raw(condition.values[2])
}
}
if op == 'begins_with':
return {
cls.path_to_raw(condition.values[0]): {
'$regex': f'^{cls.value_to_raw(condition.values[1])}'
}
}
raise NotImplementedError(f'Operator: {op} is not supported')
class UpdateExpressionConverter(_PynamoDBExpressionsConverter):
"""
Currently just SetAction and RemoveAction, ListAppend, ListPrepend
are supported, you can implement increment and decrement
"""
@classmethod
def convert(cls, action: Action):
if isinstance(action, SetAction):
path, value = action.values
if isinstance(value, Value):
return {
'$set': {cls.path_to_raw(path): cls.value_to_raw(value)}
}
            if isinstance(value, _ListAppend):
                # appending one list attribute to another is not supported
                # here, although DynamoDB itself seems to support it
if isinstance(value.values[0], Path): # append
return {
'$push': {cls.path_to_raw(path): {
'$each': cls.value_to_raw(value.values[1])}
}
}
else: # prepend
return {
'$push': {
cls.path_to_raw(path): {
'$each': cls.value_to_raw(value.values[0]),
'$position': 0
},
}
}
            # The snippet below does not work as-is, but the idea is right;
            # it only needs the correct MongoDB update query.
# if isinstance(value, _Increment):
# return {
# '$set': {cls.path_to_raw(path): {
# '$add': [f'${cls.path_to_raw(value.values[0])}', int(cls.value_to_raw(value.values[1]))] # make sure it's int, it is your responsibility
# }}
# }
# if isinstance(value, _Decrement):
# return {
# '$set': {cls.path_to_raw(path): {
# '$add': [f'${cls.path_to_raw(value.values[0])}', -int(cls.value_to_raw(value.values[1]))] # make sure it's int, it is your responsibility
# }}
# }
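            # A possible working alternative (a sketch, not part of the
            # original code): when the increment/decrement target is the
            # attribute being set, MongoDB's $inc covers both cases, e.g.
            #   {'$inc': {cls.path_to_raw(path): int(amount)}}
            # with a negative amount for _Decrement.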
raise NotImplementedError(
f'Operand of type: {value.__class__.__name__} not supported'
)
if isinstance(action, RemoveAction):
path, = action.values
return {
'$unset': {cls.path_to_raw(path): ""}
# empty string does not matter https://www.mongodb.com/docs/manual/reference/operator/update/unset/#mongodb-update-up.-unset
}
raise NotImplementedError(
f'Action {action.__class__.__name__} is not implemented'
)
class PynamoDBToPyMongoAdapter:
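    """
    Implements a subset of the PynamoDB Model API (get/save/delete/update,
    query, scan, count and batch operations) on top of a MongoDB collection
    resolved from the model's Meta.table_name.
    """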
def __init__(self, mongodb_connection: MongoDBConnection):
self.mongodb = mongodb_connection
def batch_get(self, model_class, items, attributes_to_get=None):
collection = self._collection_from_model(model_class)
query_params = []
hash_key_name, range_key_name = self.__get_table_keys(model_class)
        if not items:
            return []
        if isinstance(items[0], tuple):
query_params.append({'$or': [{hash_key_name: item[0],
range_key_name: item[1]}
for item in items]})
else:
query_params.append(
{'$or': [{hash_key_name: item} for item in items]})
raw_items = collection.find(*query_params)
return [model_class.from_json(
model_json=self.mongodb.decode_keys(item),
attributes_to_get=attributes_to_get)
for item in raw_items]
def delete(self, model_instance):
collection = self._collection_from_model(model_instance)
query = {}
try:
query = model_instance.get_keys()
except AttributeError:
if isinstance(model_instance.attribute_values, dict):
query = model_instance.attribute_values
collection.delete_one(query)
def save(self, model_instance):
json_to_save = model_instance.dynamodb_model()
collection = self._collection_from_model(model_instance)
json_to_save.pop('mongo_id', None)
encoded_document = self.mongodb.encode_keys(
{
key: value for key, value in json_to_save.items()
if value is not None
}
)
collection.replace_one(model_instance.get_keys(),
encoded_document, upsert=True)
def update(self, model_instance, actions: List[Action],
condition: Optional[Condition] = None,
settings: OperationSettings = OperationSettings.default):
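        # NOTE: the condition and settings parameters are accepted for API
        # compatibility but are currently not applied to the MongoDB update.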
collection = self._collection_from_model(model_instance)
_update = {}
for dct in [UpdateExpressionConverter.convert(a) for a in actions]:
for action, query in dct.items():
_update.setdefault(action, {}).update(query)
res = collection.find_one_and_update(
filter=model_instance.get_keys(),
update=_update,
upsert=True,
return_document=ReturnDocument.AFTER
)
if res:
type(model_instance).from_json(res, instance=model_instance)
def get(self, model_class, hash_key, range_key=None) -> Model:
result = self.get_nullable(model_class=model_class,
hash_key=hash_key,
sort_key=range_key)
if not result:
raise model_class.DoesNotExist()
return result
def get_nullable(self, model_class, hash_key, sort_key=None
) -> Optional[Model]:
hash_key_name, range_key_name = self.__get_table_keys(model_class)
        if not hash_key_name:
            raise AssertionError('Cannot identify the hash key name of '
                                 f'model: \'{model_class.__name__}\'')
        if sort_key and not range_key_name:
            raise AssertionError(
                f'The range key value is specified for '
                f'model \'{model_class.__name__}\' but there is no '
                f'attribute in the model marked as range_key')
collection = self._collection_from_model(model_class)
params = {hash_key_name: hash_key}
if range_key_name and sort_key:
params[range_key_name] = sort_key
raw_item = collection.find_one(params)
if raw_item:
raw_item = self.mongodb.decode_keys(raw_item)
return model_class.from_json(raw_item)
def query(self, model_class, hash_key, range_key_condition=None,
filter_condition=None, limit=None, last_evaluated_key=None,
attributes_to_get=None, scan_index_forward=True):
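        """
        Mimics a DynamoDB query on top of MongoDB: the hash key, the range
        key condition and the filter condition are merged into a single
        find() query; last_evaluated_key is treated as a numeric offset
        (skip) and limit maps to the cursor limit.
        """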
# works both for Model and Index
hash_key_name = getattr(model_class._hash_key_attribute(),
'attr_name', None)
range_key_name = getattr(model_class._range_key_attribute(),
'attr_name', None)
if issubclass(model_class, indexes.Index):
model_class = model_class.Meta.model
collection = self._collection_from_model(model_class)
_query = {hash_key_name: hash_key}
if range_key_condition is not None:
_query.update(ConditionConverter.convert(range_key_condition))
if filter_condition is not None:
_query.update(ConditionConverter.convert(filter_condition))
limit = limit or 0 # ZERO means no limit
last_evaluated_key = last_evaluated_key or 0
cursor = collection.find(_query).limit(limit).skip(last_evaluated_key)
if range_key_name:
cursor = cursor.sort(
range_key_name, ASCENDING if scan_index_forward else
DESCENDING
)
return Result(
result=(model_class.from_json(self.mongodb.decode_keys(i),
attributes_to_get) for i in cursor),
_evaluated_key=last_evaluated_key,
page_size=collection.count_documents(_query)
)
def scan(self, model_class, filter_condition=None, limit=None,
last_evaluated_key=None, attributes_to_get=None):
collection = self._collection_from_model(model_class)
_query = {}
if filter_condition is not None:
_query.update(ConditionConverter.convert(filter_condition))
limit = limit or 0 # ZERO means no limit
last_evaluated_key = last_evaluated_key or 0
cursor = collection.find(_query).limit(limit).skip(last_evaluated_key)
return Result(
result=(model_class.from_json(self.mongodb.decode_keys(i),
attributes_to_get) for i in cursor),
_evaluated_key=last_evaluated_key,
page_size=collection.count_documents(_query)
)
def refresh(self, consistent_read):
raise NotImplementedError
def _collection_from_model(self, model: Model) -> Collection:
name = model.Meta.table_name
return self.mongodb.collection(collection_name=name)
def count(self, model_class, hash_key=None,
range_key_condition=None,
filter_condition=None,
index_name=None,
limit=None) -> int:
collection = self._collection_from_model(model_class)
hash_key_name = getattr(model_class._hash_key_attribute(),
'attr_name', None)
if index_name:
hash_key_name = getattr(
model_class._indexes[index_name]._hash_key_attribute(),
'attr_name', None
)
_query = {hash_key_name: hash_key}
if range_key_condition is not None:
_query.update(ConditionConverter.convert(range_key_condition))
if filter_condition is not None:
_query.update(ConditionConverter.convert(filter_condition))
if limit:
return collection.count_documents(_query, limit=limit)
return collection.count_documents(_query)
def batch_write(self, model_class) -> BatchWrite:
return BatchWrite(model=model_class, mongo_connection=self.mongodb)
@staticmethod
def __get_table_keys(model_class) -> tuple:
short_to_body_mapping = {attr_body.attr_name: attr_body
for attr_name, attr_body in
model_class._attributes.items()}
hash_key_name = None
range_key_name = None
for short_name, body in short_to_body_mapping.items():
if body.is_hash_key:
hash_key_name = short_name
continue
if body.is_range_key:
range_key_name = short_name
continue
return hash_key_name, range_key_name