src/Epam.GraphQL/Helpers/ExpressionHelpers.cs (363 lines of code) (raw):

// Copyright © 2020 EPAM Systems, Inc. All Rights Reserved. All information contained herein is, and remains the
// property of EPAM Systems, Inc. and/or its suppliers and is protected by international intellectual
// property law. Dissemination of this information or reproduction of this material is strictly forbidden,
// unless prior written permission is obtained from EPAM Systems, Inc

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Epam.GraphQL.Extensions;

namespace Epam.GraphQL.Helpers
{
    /// <summary>
    /// Helpers for composing, rewriting and factorizing LINQ expression trees.
    /// </summary>
    internal static class ExpressionHelpers
    {
        /// <summary>
        /// Builds <c>x =&gt; second(first(x))</c> by inlining the body of <paramref name="first"/>
        /// in place of the parameter of <paramref name="second"/>.
        /// </summary>
        /// <param name="first">Inner single-parameter lambda.</param>
        /// <param name="second">Outer single-parameter lambda, applied to the result of <paramref name="first"/>.</param>
        /// <returns>A new lambda over a freshly created parameter.</returns>
        public static Expression<Func<T1, TResult>> Compose<T1, T2, TResult>(Expression<Func<T1, T2>> first, Expression<Func<T2, TResult>> second)
        {
            var paramForReplace = first.Parameters[0];
            var param = Expression.Parameter(paramForReplace.Type, paramForReplace.Name);
            var body = ParameterRebinder<Expression, Expression>.ReplaceParameter(first.Body, paramForReplace, param);

            return Expression.Lambda<Func<T1, TResult>>(
                ParameterRebinder<Expression, Expression>.ReplaceParameter(second.Body, second.Parameters[0], body),
                param);
        }

        /// <summary>
        /// Untyped counterpart of the single-parameter composition: builds <c>x =&gt; second(first(x))</c>.
        /// Only the first parameter of each lambda is taken into account.
        /// </summary>
        /// <param name="first">Inner lambda.</param>
        /// <param name="second">Outer lambda, applied to the result of <paramref name="first"/>.</param>
        /// <returns>A new lambda over a freshly created parameter.</returns>
        public static LambdaExpression Compose(LambdaExpression first, LambdaExpression second)
        {
            var paramForReplace = first.Parameters[0];
            var param = Expression.Parameter(paramForReplace.Type, paramForReplace.Name);
            var body = ParameterRebinder<Expression, Expression>.ReplaceParameter(first.Body, paramForReplace, param);

            return Expression.Lambda(
                ParameterRebinder<Expression, Expression>.ReplaceParameter(second.Body, second.Parameters[0], body),
                param);
        }

        /// <summary>
        /// Builds <c>(a, b) =&gt; second(first(a, b))</c>.
        /// </summary>
        /// <param name="first">Inner two-parameter lambda.</param>
        /// <param name="second">Outer single-parameter lambda, applied to the result of <paramref name="first"/>.</param>
        /// <returns>A new lambda over two freshly created parameters.</returns>
        public static Expression<Func<T, T1, TResult>> Compose<T, T1, T2, TResult>(Expression<Func<T, T1, T2>> first, Expression<Func<T2, TResult>> second)
        {
            var paramForReplace1 = first.Parameters[0];
            var param1 = Expression.Parameter(paramForReplace1.Type, paramForReplace1.Name);
            var paramForReplace2 = first.Parameters[1];
            var param2 = Expression.Parameter(paramForReplace2.Type, paramForReplace2.Name);

            var body = ParameterRebinder<Expression, Expression>.ReplaceParameters(
                new Dictionary<Expression, Expression>
                {
                    [paramForReplace1] = param1,
                    [paramForReplace2] = param2,
                },
                first.Body);

            return Expression.Lambda<Func<T, T1, TResult>>(
                ParameterRebinder<Expression, Expression>.ReplaceParameter(second.Body, second.Parameters[0], body),
                param1,
                param2);
        }

        /// <summary>
        /// Builds <c>(a, b) =&gt; second(a, first(a, b))</c>.
        /// </summary>
        /// <param name="first">Inner two-parameter lambda.</param>
        /// <param name="second">
        /// Outer lambda; its first parameter receives the same first argument as <paramref name="first"/>,
        /// its second parameter receives the result of <paramref name="first"/>.
        /// </param>
        /// <returns>A new lambda over two freshly created parameters.</returns>
        public static Expression<Func<T, T1, TResult>> Compose<T, T1, T2, TResult>(Expression<Func<T, T1, T2>> first, Expression<Func<T, T2, TResult>> second)
        {
            var paramForReplace1 = first.Parameters[0];
            var param1 = Expression.Parameter(paramForReplace1.Type, paramForReplace1.Name);
            var paramForReplace2 = first.Parameters[1];
            var param2 = Expression.Parameter(paramForReplace2.Type, paramForReplace2.Name);

            var body = ParameterRebinder<Expression, Expression>.ReplaceParameters(
                new Dictionary<Expression, Expression>
                {
                    [paramForReplace1] = param1,
                    [paramForReplace2] = param2,
                },
                first.Body);

            // Both parameters of `second` have to be rebound in a single pass: the first one to the
            // resulting lambda's first parameter (otherwise a `second` body that uses it would be left
            // with a parameter that is not declared by the returned lambda), the second one to the
            // inlined body of `first`.
            var composedBody = ParameterRebinder<Expression, Expression>.ReplaceParameters(
                new Dictionary<Expression, Expression>
                {
                    [second.Parameters[0]] = param1,
                    [second.Parameters[1]] = body,
                },
                second.Body);

            return Expression.Lambda<Func<T, T1, TResult>>(composedBody, param1, param2);
        }

        /// <summary>
        /// Null-propagating variant of the two-parameter composition:
        /// builds <c>(a, b) =&gt; first(a, b) == null ? null : second(a, first(a, b))</c>
        /// when both <typeparamref name="T2"/> and <typeparamref name="TResult"/> can hold <c>null</c>
        /// (reference type or nullable value type); otherwise falls back to plain composition.
        /// </summary>
        /// <param name="first">Inner two-parameter lambda.</param>
        /// <param name="second">Outer lambda applied to the first argument and the result of <paramref name="first"/>.</param>
        /// <returns>A new lambda over two freshly created parameters.</returns>
        public static Expression<Func<T, T1, TResult>> SafeCompose<T, T1, T2, TResult>(Expression<Func<T, T1, T2>> first, Expression<Func<T, T2, TResult>> second)
        {
            if ((!typeof(T2).IsValueType || typeof(T2).IsNullable()) && (!typeof(TResult).IsValueType || typeof(TResult).IsNullable()))
            {
                var paramForReplace1 = first.Parameters[0];
                var param1 = Expression.Parameter(paramForReplace1.Type, paramForReplace1.Name);
                var paramForReplace2 = first.Parameters[1];
                var param2 = Expression.Parameter(paramForReplace2.Type, paramForReplace2.Name);

                var body = ParameterRebinder<Expression, Expression>.ReplaceParameters(
                    new Dictionary<Expression, Expression>
                    {
                        [paramForReplace1] = param1,
                        [paramForReplace2] = param2,
                    },
                    first.Body);

                // Rebind the first parameter of `second` to the resulting lambda's first parameter, so
                // that a `second` body using it does not end up referencing an undeclared parameter.
                // This is done as a separate pass with a reference-equality map to keep the
                // comparer-based map below exactly as it was.
                var reboundSecondBody = ParameterRebinder<Expression, Expression>.ReplaceParameter(second.Body, second.Parameters[0], param1);

                // The map is keyed with ExpressionEqualityComparer: besides substituting the second
                // parameter, any sub-expression the comparer treats as equal to `body` is replaced by
                // the `body` instance itself (and is not traversed further).
                var secondBody = ParameterRebinder<Expression, Expression>.ReplaceParameters(
                    new Dictionary<Expression, Expression>(ExpressionEqualityComparer.Instance)
                    {
                        [body] = body,
                        [second.Parameters[1]] = body,
                    },
                    reboundSecondBody);

                var testExpr = Expression.Equal(body, Expression.Constant(null, body.Type));
                var conditionExpr = Expression.Condition(testExpr, Expression.Constant(null, secondBody.Type), secondBody);

                return Expression.Lambda<Func<T, T1, TResult>>(conditionExpr, param1, param2);
            }

            return Compose(first, second);
        }

        /// <summary>
        /// Builds <c>item =&gt; ids.Contains(keySelector(item))</c>.
        /// </summary>
        /// <param name="ids">Identifiers to test membership against.</param>
        /// <param name="keySelector">Selects the key of an item.</param>
        /// <returns>A predicate over <typeparamref name="TItem"/>.</returns>
        public static Expression<Func<TItem, bool>> MakeContainsExpression<TItem, TId>(IEnumerable<TId> ids, Expression<Func<TItem, TId>> keySelector)
        {
            // The sequence is reached through a property of a constant tuple rather than embedded as a
            // constant directly (NOTE(review): presumably so that query providers treat it like a
            // captured variable - confirm).
            var value = Tuple.Create(ids);
            var valueConstExpression = Expression.Constant(value);
            var idExpression = Expression.Property(valueConstExpression, CachedReflectionInfo.ForTuple<IEnumerable<TId>>.Item1);

            var paramExpression = Expression.Parameter(typeof(TItem), keySelector.Parameters[0].Name);
            var keySelectorExpression = ParameterRebinder<Expression, ParameterExpression>.ReplaceParameter(keySelector.Body, keySelector.Parameters[0], paramExpression);
            var callExpression = Expression.Call(CachedReflectionInfo.ForEnumerable.Contains<TId>(), idExpression, keySelectorExpression);

            return Expression.Lambda<Func<TItem, bool>>(callExpression, paramExpression);
        }

        /// <summary>
        /// Builds <c>item =&gt; key.HasValue &amp;&amp; ids.Contains(key.Value)</c>, where <c>key</c> is
        /// <c>keySelector(item)</c>, for a nullable key.
        /// </summary>
        /// <param name="ids">Identifiers to test membership against.</param>
        /// <param name="keySelector">Selects the nullable key of an item.</param>
        /// <returns>A predicate over <typeparamref name="TItem"/>.</returns>
        public static Expression<Func<TItem, bool>> MakeContainsExpression<TItem, TId>(IEnumerable<TId> ids, Expression<Func<TItem, TId?>> keySelector)
            where TId : struct
        {
            var value = Tuple.Create(ids);
            var valueConstExpression = Expression.Constant(value);
            var idExpression = Expression.Property(valueConstExpression, CachedReflectionInfo.ForTuple<IEnumerable<TId>>.Item1);

            var paramExpression = Expression.Parameter(typeof(TItem), keySelector.Parameters[0].Name);
            var keySelectorExpression = ParameterRebinder<Expression, ParameterExpression>.ReplaceParameter(keySelector.Body, keySelector.Parameters[0], paramExpression);
            var valueAccessExpression = Expression.Property(keySelectorExpression, CachedReflectionInfo.ForNullable<TId>.Value);
            var hasValueAccessExpression = Expression.Property(keySelectorExpression, CachedReflectionInfo.ForNullable<TId>.HasValue);
            var callExpression = Expression.Call(CachedReflectionInfo.ForEnumerable.Contains<TId>(), idExpression, valueAccessExpression);
            var andExpression = Expression.AndAlso(hasValueAccessExpression, callExpression);

            return Expression.Lambda<Func<TItem, bool>>(andExpression, paramExpression);
        }

        /// <summary>
        /// Builds <c>entity =&gt; selector(entity).Value</c> for a nullable selector. Results are cached
        /// per <typeparamref name="TEntity"/>/<typeparamref name="TProperty"/> pair.
        /// </summary>
        /// <param name="selector">Selects a nullable property of the entity.</param>
        /// <returns>A lambda returning the underlying (non-nullable) value.</returns>
        public static Expression<Func<TEntity, TProperty>> MakeValueAccessExpression<TEntity, TProperty>(Expression<Func<TEntity, TProperty?>> selector)
            where TProperty : struct
        {
            return ValueAccessExpressionCacher<TEntity, TProperty>.MakeValueAccessExpression(selector);
        }

        /// <summary>
        /// Wraps a single-parameter lambda into <c>(object o) =&gt; (object)selector((TParam)o)</c>.
        /// </summary>
        /// <param name="selector">Lambda to wrap; validated by <c>ShouldHaveOnlyOneParameter</c>.</param>
        /// <returns>An object-to-object lambda with conversions on input and output.</returns>
        public static Expression<Func<object?, object?>> MakeWeakLambdaExpression(LambdaExpression selector)
        {
            selector.ShouldHaveOnlyOneParameter(nameof(selector));

            var paramExpression = Expression.Parameter(typeof(object), selector.Parameters[0].Name);
            var convertParamExpression = Expression.Convert(paramExpression, selector.Parameters[0].Type);
            var keySelectorExpression = ParameterRebinder<Expression, UnaryExpression>.ReplaceParameter(selector.Body, selector.Parameters[0], convertParamExpression);
            var convertResultExpression = Expression.Convert(keySelectorExpression, typeof(object));
            var result = Expression.Lambda<Func<object?, object?>>(convertResultExpression, paramExpression);

            return result;
        }

        /// <summary>
        /// Builds the identity lambda <c>x =&gt; x</c> for the given type.
        /// </summary>
        /// <param name="type">Type of the parameter (and of the result).</param>
        public static LambdaExpression MakeIdentity(Type type)
        {
            var param = Expression.Parameter(type);
            return Expression.Lambda(param, param);
        }

        /// <summary>
        /// Splits a two-parameter expression into sub-expressions that depend on the first parameter only,
        /// sub-expressions that depend on the second parameter only, and the remaining expression that
        /// combines them.
        /// </summary>
        /// <param name="expression">Expression to factorize.</param>
        public static ExpressionFactorizationResult Factorize<T1, T2, TResult>(Expression<Func<T1, T2, TResult>> expression)
        {
            return FactorizationVisitor.Factorize(expression);
        }

        /// <summary>
        /// Same as <see cref="TryFactorizeCondition{T1, T2}"/>, but throws when the condition cannot be factorized.
        /// </summary>
        /// <param name="condition">Condition relating <typeparamref name="T1"/> and <typeparamref name="T2"/>.</param>
        /// <exception cref="ArgumentOutOfRangeException">The condition cannot be factorized.</exception>
        public static ConditionFactorizationResult<T2> FactorizeCondition<T1, T2>(Expression<Func<T1, T2, bool>> condition)
        {
            if (TryFactorizeCondition(condition, out var result))
            {
                return result!;
            }

            throw new ArgumentOutOfRangeException(nameof(condition), $"Cannot use expression {condition} as a relation between {typeof(T1).HumanizedName()} and {typeof(T2).HumanizedName()} types.");
        }

        /// <summary>
        /// Tries to represent a condition as <c>left(x) == right(y) [&amp;&amp; rightCondition(y)]</c>.
        /// </summary>
        /// <param name="expression">Condition relating <typeparamref name="T1"/> and <typeparamref name="T2"/>.</param>
        /// <param name="result">
        /// On success, the left expression, the right expression taking part in the equality and the optional
        /// conjunction of the remaining right-only expressions; otherwise <c>null</c>.
        /// </param>
        /// <returns><c>true</c> if the condition has been factorized; otherwise <c>false</c>.</returns>
        public static bool TryFactorizeCondition<T1, T2>(Expression<Func<T1, T2, bool>> expression, [NotNullWhen(true)] out ConditionFactorizationResult<T2>? result)
        {
            var factorizationResult = Factorize(expression);

            if (factorizationResult.LeftExpressions.Count == 1 && factorizationResult.RightExpressions.Count != 0)
            {
                var leftExpression = factorizationResult.LeftExpressions[0];
                LambdaExpression? rightExpression = null;
                Expression<Func<T2, bool>>? rightCondition = null;

                var operands = GetAndAlsoExpressions(factorizationResult.Expression.Body);

                // Parameters of the factorized expression: the single left placeholder first, then one
                // placeholder per right expression, so parameters[i + 1] corresponds to RightExpressions[i].
                var parameters = factorizationResult.Expression.Parameters;
                var leftParam = parameters[0];
                var equalOperand = operands.Single(op => op.ContainsParameter(leftParam));

                if (equalOperand.NodeType == ExpressionType.Equal)
                {
                    // Find the right expression that takes part in the equality with the left one.
                    var index = -1;
                    for (int i = 0; i < parameters.Count - 1; i++)
                    {
                        if (equalOperand.ContainsParameter(parameters[i + 1]))
                        {
                            index = i;
                            break;
                        }
                    }

                    if (index != -1)
                    {
                        var rightExpressions = factorizationResult.RightExpressions;
                        Expression? left = null;
                        var param = Expression.Parameter(typeof(T2));

                        // All right expressions except the one used in the equality are AND-ed together
                        // into an additional condition over the right side.
                        for (int i = 0; i < rightExpressions.Count; i++)
                        {
                            if (i != index)
                            {
                                var right = ParameterRebinder<Expression, ParameterExpression>.ReplaceParameter(rightExpressions[i].Body, rightExpressions[i].Parameters[0], param);
                                if (left == null)
                                {
                                    left = right;
                                    continue;
                                }

                                left = Expression.AndAlso(left, right);
                            }
                            else
                            {
                                rightExpression = rightExpressions[i];
                            }
                        }

                        if (left != null)
                        {
                            rightCondition = Expression.Lambda<Func<T2, bool>>(left, param);
                        }

                        if (rightExpression != null)
                        {
                            result = new ConditionFactorizationResult<T2>(leftExpression, rightExpression, rightCondition);
                            return true;
                        }
                    }
                }
            }

            result = null;
            return false;
        }

        /// <summary>
        /// Creates a builder of <c>x =&gt; new TResult { Prop = ..., ... }</c> lambdas.
        /// </summary>
        /// <param name="paramType">Type of the lambda parameter.</param>
        public static MemberInitBuilder<TResult> MakeMemberInit<TResult>(Type paramType)
        {
            return new MemberInitBuilder<TResult>(paramType);
        }

        /// <summary>
        /// Flattens a left-nested chain of <c>&amp;&amp;</c> operations into its operands, in source order.
        /// An expression that is not an <c>&amp;&amp;</c> yields a single-element list.
        /// </summary>
        private static IReadOnlyList<Expression> GetAndAlsoExpressions(Expression expression)
        {
            var expressions = new List<Expression>();

            // Operands are collected right-to-left while walking down the left spine, then reversed.
            while (expression is BinaryExpression binaryExpression && expression.NodeType == ExpressionType.AndAlso)
            {
                expressions.Add(binaryExpression.Right);
                expression = binaryExpression.Left;
            }

            expressions.Add(expression);
            expressions.Reverse();

            return expressions;
        }

        /// <summary>
        /// Accumulates property initializers and produces a member-init lambda for <typeparamref name="TResult"/>.
        /// </summary>
        public class MemberInitBuilder<TResult>
        {
            private readonly Dictionary<PropertyInfo, LambdaExpression> _assignments = new();
            private readonly ParameterExpression _param;

            public MemberInitBuilder(Type paramType)
            {
                _param = Expression.Parameter(paramType);
            }

            /// <summary>
            /// Registers an initializer for a property. Registering an equal initializer for the same
            /// property again is a no-op.
            /// </summary>
            /// <param name="property">Property of <typeparamref name="TResult"/> to initialize.</param>
            /// <param name="initializer">Lambda whose body computes the property value.</param>
            /// <returns>This builder.</returns>
            /// <exception cref="InvalidOperationException">
            /// A different initializer has already been registered for <paramref name="property"/>.
            /// </exception>
            public MemberInitBuilder<TResult> Property(PropertyInfo property, LambdaExpression initializer)
            {
                if (_assignments.TryGetValue(property, out var existingInitializer))
                {
                    if (!ExpressionEqualityComparer.Instance.Equals(initializer, existingInitializer))
                    {
                        throw new InvalidOperationException($"Attempt to create two different initilizers ({existingInitializer} and {initializer})for a property {property.Name} of type {typeof(TResult)}.");
                    }

                    return this;
                }

                _assignments.Add(property, initializer);
                return this;
            }

            /// <summary>
            /// Builds <c>x =&gt; new TResult { ... }</c> from the registered initializers, rebinding each
            /// initializer's first parameter to the shared lambda parameter.
            /// <typeparamref name="TResult"/> must have exactly one public parameterless constructor.
            /// </summary>
            public LambdaExpression Lambda()
            {
                var ctor = typeof(TResult).GetConstructors().Single(c => c.GetParameters().Length == 0);
                var newExpr = Expression.New(ctor);
                var memberAssignments = _assignments.Select(kv => Expression.Bind(kv.Key, kv.Value.Body.ReplaceParameter(kv.Value.Parameters[0], _param)));
                var initExpr = Expression.MemberInit(newExpr, memberAssignments);

                return Expression.Lambda(initExpr, _param);
            }
        }

        /// <summary>
        /// Expression visitor that substitutes nodes found in a map with their replacements.
        /// </summary>
        public class ParameterRebinder<TKey, T> : ExpressionVisitor
            where T : Expression
            where TKey : Expression
        {
            private readonly IReadOnlyDictionary<TKey, T> _map;

            protected ParameterRebinder(IReadOnlyDictionary<TKey, T> map)
            {
                _map = map;
            }

            /// <summary>
            /// Returns <paramref name="exp"/> with every node that is a key of <paramref name="map"/>
            /// replaced by the mapped value. Key lookup uses the comparer of the supplied dictionary.
            /// </summary>
            public static Expression ReplaceParameters(IReadOnlyDictionary<TKey, T> map, Expression exp)
            {
                return new ParameterRebinder<TKey, T>(map).Visit(exp);
            }

            /// <summary>
            /// Returns <paramref name="expression"/> with <paramref name="parameterExpression"/> replaced
            /// by <paramref name="newExpression"/> (looked up with the default equality comparer).
            /// </summary>
            public static Expression ReplaceParameter(
                Expression expression,
                TKey parameterExpression,
                T newExpression) => ReplaceParameters(
                    new Dictionary<TKey, T>
                    {
                        [parameterExpression] = newExpression,
                    },
                    expression);

            public override Expression Visit(Expression p)
            {
                // A replacement is returned as is; it is not visited again.
                if (p is TKey key && _map.TryGetValue(key, out var replacement))
                {
                    return replacement;
                }

                return base.Visit(p);
            }
        }

        /// <summary>
        /// Per-type-pair cache of <c>entity =&gt; selector(entity).Value</c> lambdas, keyed by the
        /// selector using <c>ExpressionEqualityComparer</c>.
        /// </summary>
        private static class ValueAccessExpressionCacher<TEntity, TProperty>
            where TProperty : struct
        {
            private static readonly ConcurrentDictionary<Expression<Func<TEntity, TProperty?>>, Expression<Func<TEntity, TProperty>>> _cache = new(
                ExpressionEqualityComparer.Instance);

            public static Expression<Func<TEntity, TProperty>> MakeValueAccessExpression(Expression<Func<TEntity, TProperty?>> selector)
            {
                return _cache.GetOrAdd(selector, key =>
                {
                    var paramExpression = Expression.Parameter(typeof(TEntity), key.Parameters[0].Name);
                    var selectorExpression = ParameterRebinder<Expression, Expression>.ReplaceParameter(key.Body, key.Parameters[0], paramExpression);
                    var valueAccessExpression = Expression.Property(selectorExpression, CachedReflectionInfo.ForNullable<TProperty>.Value);
                    var result = Expression.Lambda<Func<TEntity, TProperty>>(valueAccessExpression, paramExpression);

                    return result;
                });
            }
        }

        /// <summary>
        /// Replaces every maximal sub-expression that references exactly one of the lambda parameters
        /// with a fresh placeholder parameter, recording the replaced sub-expression.
        /// </summary>
        private class FactorizationVisitor : ExpressionVisitor
        {
            private readonly List<Expression> _leftExpressions = new();
            private readonly List<Expression> _rightExpressions = new();
            private readonly IReadOnlyList<ParameterExpression> _parameters;
            private readonly List<ParameterExpression> _leftParameters = new();
            private readonly List<ParameterExpression> _rightParameters = new();

            private FactorizationVisitor(IReadOnlyCollection<ParameterExpression> parameters)
            {
                _parameters = parameters.ToList();
            }

            /// <summary>
            /// Factorizes <paramref name="expression"/>. The resulting combining expression takes the
            /// left placeholders followed by the right placeholders as its parameters.
            /// </summary>
            public static ExpressionFactorizationResult Factorize<T1, T2, TResult>(Expression<Func<T1, T2, TResult>> expression)
            {
                var visitor = new FactorizationVisitor(expression.Parameters);
                var resultExpression = visitor.Visit(expression.Body);

                var result = new ExpressionFactorizationResult(
                    leftExpressions: visitor._leftExpressions.Select(e => Expression.Lambda(e, visitor._parameters[0])).ToList(),
                    rightExpressions: visitor._rightExpressions.Select(e => Expression.Lambda(e, visitor._parameters[1])).ToList(),
                    expression: Expression.Lambda(resultExpression, visitor._leftParameters.Concat(visitor._rightParameters)));

                return result;
            }

            public override Expression Visit(Expression node)
            {
                var parameters = ParameterVisitor.GetParameters(node, _parameters);

                // A node that depends on exactly one of the original parameters is cut out as a whole;
                // nodes depending on none or on both are traversed further.
                if (parameters.Count == 1)
                {
                    if (parameters[0] == _parameters[0])
                    {
                        var parameter = Expression.Parameter(node.Type, $"Left_{_leftParameters.Count}");
                        _leftParameters.Add(parameter);
                        _leftExpressions.Add(node);
                        return parameter;
                    }
                    else
                    {
                        var parameter = Expression.Parameter(node.Type, $"Right_{_rightParameters.Count}");
                        _rightParameters.Add(parameter);
                        _rightExpressions.Add(node);
                        return parameter;
                    }
                }

                return base.Visit(node);
            }
        }

        /// <summary>
        /// Collects which of the allowed parameters are referenced by an expression.
        /// </summary>
        private class ParameterVisitor : ExpressionVisitor
        {
            private readonly HashSet<ParameterExpression> _parameters = new();
            private readonly HashSet<ParameterExpression> _allowedParameters;

            private ParameterVisitor(IEnumerable<ParameterExpression> allowedParameters)
            {
                _allowedParameters = new HashSet<ParameterExpression>(allowedParameters);
            }

            /// <summary>
            /// Returns the distinct parameters from <paramref name="allowedParameters"/> that occur in
            /// <paramref name="expression"/>.
            /// </summary>
            public static IReadOnlyList<ParameterExpression> GetParameters(Expression expression, IEnumerable<ParameterExpression> allowedParameters)
            {
                var visitor = new ParameterVisitor(allowedParameters);
                visitor.Visit(expression);

                return visitor._parameters.ToList();
            }

            protected override Expression VisitParameter(ParameterExpression node)
            {
                if (_allowedParameters.Contains(node))
                {
                    _parameters.Add(node);
                }

                return base.VisitParameter(node);
            }
        }
    }
}