I am following an example series on MSDN for creating a LINQ Provider and have hit a wall.
I expect that when I run the following test, the ExpressionVisitor subclass in the source code below has its VisitMethodCall method invoked.
// Test from the question: builds a LightmapQuery, calls Where on it, then enumerates the query.
[Fact]
public void DatabaseModeler_provides_table_modeler()
{
var q = new LightmapQuery<AspNetRoles>(new SqliteProvider2());
// The value returned by Where is not stored, so the ToList() below enumerates q itself.
q.Where(role => role.Name == "Admin");
var result = q.ToList();
}
What happens instead is that the VisitConstant method is invoked. I assume this is because when the query is instantiated, it assigns its Expression property a ConstantExpression. I'm not sure whether I am doing something wrong or whether the guide on MSDN has an issue that prevents me from getting the expression containing the Where method invocation.
This is the source code I have for the IQueryable<T> and IQueryProvider implementations.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
namespace Lightmap.Querying
{
/// <summary>
/// Walks a LINQ expression tree and builds a SQL string for it.
/// Only <c>Queryable.Where</c> calls are recognised; any other method call is rejected.
/// </summary>
internal class SqliteQueryTranslator : ExpressionVisitor
{
    /// <summary>Accumulates the SQL text; re-created on every <see cref="Translate"/> call.</summary>
    internal StringBuilder queryBuilder;

    /// <summary>Translates <paramref name="expression"/> and returns the SQL text produced while visiting it.</summary>
    internal string Translate(Expression expression)
    {
        this.queryBuilder = new StringBuilder();
        this.Visit(expression);
        return this.queryBuilder.ToString();
    }

    /// <summary>Pass-through override; a convenient single place to put a breakpoint while debugging.</summary>
    public override Expression Visit(Expression node)
    {
        return base.Visit(node);
    }

    /// <summary>
    /// Handles a method call node. Anything other than <c>Queryable.Where</c> is unsupported.
    /// </summary>
    /// <exception cref="NotSupportedException">The call is not <c>Queryable.Where</c>.</exception>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // LINQ query operators are declared on the static Queryable class (not the IQueryable
        // interface), and the call must be rejected when EITHER the declaring type or the name differs.
        if (node.Method.DeclaringType != typeof(Queryable) || node.Method.Name != nameof(Queryable.Where))
        {
            throw new NotSupportedException($"The {node.Method.Name} method is not supported.");
        }

        // Queryable.Where<TSource> has exactly one generic argument: the element (table) type.
        // The declaring type's name would always be "Queryable", which is not a table.
        Type elementType = node.Method.GetGenericArguments()[0];
        this.queryBuilder.Append($"SELECT * FROM {elementType.Name}");
        return node;
    }

    /// <summary>Constants (including the root query object) produce no SQL.</summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        return node;
    }

    /// <summary>Unwraps any number of nested <c>Quote</c> nodes and returns the inner expression.</summary>
    private static Expression StripQuotes(Expression expression)
    {
        while (expression.NodeType == ExpressionType.Quote)
        {
            expression = ((UnaryExpression)expression).Operand;
        }

        return expression;
    }
}
/// <summary>
/// Base <see cref="IQueryProvider"/> that creates <c>LightmapQuery&lt;T&gt;</c> instances and
/// delegates execution to the abstract <see cref="Execute(Expression)"/>.
/// </summary>
public abstract class LightmapProvider : IQueryProvider
{
    /// <summary>
    /// Creates a non-generic query for <paramref name="expression"/>, using the element type of the
    /// expression's static type.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        // expression.Type is the type the expression evaluates to (e.g. IQueryable<T>);
        // expression.GetType() would be the node class (MethodCallExpression, ...), which has no
        // generic arguments and made First() throw.
        Type elementType = GetElementType(expression.Type);
        return (IQueryable)Activator.CreateInstance(
            typeof(LightmapQuery<>).MakeGenericType(elementType),
            new object[] { this, expression });
    }

    /// <summary>Creates a strongly typed query over <paramref name="expression"/>.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) => new LightmapQuery<TElement>(this, expression);

    object IQueryProvider.Execute(Expression expression)
    {
        return this.Execute(expression);
    }

    /// <summary>Executes <paramref name="expression"/> and casts the result to <typeparamref name="TResult"/>.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        return (TResult)this.Execute(expression);
    }

    /// <summary>Returns the query text a concrete provider would run for <paramref name="expression"/>.</summary>
    public abstract string GetQueryText(Expression expression);

    /// <summary>Executes <paramref name="expression"/> and returns the untyped result.</summary>
    public abstract object Execute(Expression expression);

    /// <summary>
    /// Returns T when <paramref name="sequenceType"/> is or implements IEnumerable&lt;T&gt;;
    /// otherwise returns <paramref name="sequenceType"/> itself.
    /// </summary>
    private static Type GetElementType(Type sequenceType)
    {
        Type enumerableType = new[] { sequenceType }
            .Concat(sequenceType.GetInterfaces())
            .FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));

        return enumerableType != null ? enumerableType.GetGenericArguments()[0] : sequenceType;
    }
}
/// <summary>
/// Provider used by the test: runs the translator over the expression and then returns an
/// empty <c>List&lt;&gt;</c> closed over the type reported by <c>TypeCache.GetGenericParameter</c>.
/// </summary>
public class SqliteProvider2 : LightmapProvider
{
    /// <summary>Translates the expression (the SQL is not used yet) and returns a new empty list.</summary>
    public override object Execute(Expression expression)
    {
        var translator = new SqliteQueryTranslator();
        var sql = translator.Translate(expression);

        var elementType = TypeCache.GetGenericParameter(expression.Type, t => true);
        var listType = typeof(List<>).MakeGenericType(elementType);
        return Activator.CreateInstance(listType);
    }

    /// <summary>Not implemented.</summary>
    public override string GetQueryText(Expression expression) => throw new NotImplementedException();
}
/// <summary>
/// Queryable root/wrapper for <typeparamref name="TTable"/>. Enumerating it asks the
/// <see cref="Provider"/> to execute the current <see cref="Expression"/>.
/// </summary>
public class LightmapQuery<TTable> : IOrderedQueryable<TTable>
{
    /// <summary>Creates a root query whose expression is a constant referring to this instance.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
    public LightmapQuery(IQueryProvider provider)
    {
        this.Provider = provider ?? throw new ArgumentNullException(nameof(provider));
        this.Expression = Expression.Constant(this);
    }

    /// <summary>Creates a query over an existing expression tree.</summary>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="expression"/> does not evaluate to an <c>IQueryable&lt;TTable&gt;</c>.
    /// </exception>
    public LightmapQuery(IQueryProvider provider, Expression expression)
    {
        if (provider == null)
        {
            throw new ArgumentNullException(nameof(provider));
        }

        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        if (!typeof(IQueryable<TTable>).IsAssignableFrom(expression.Type))
        {
            // A specific exception type (still catchable as Exception) instead of a bare Exception.
            throw new ArgumentOutOfRangeException(nameof(expression));
        }

        this.Expression = expression;
        this.Provider = provider;
    }

    /// <summary>The element type of the query.</summary>
    public Type ElementType => typeof(TTable);

    /// <summary>The expression tree this query represents.</summary>
    public Expression Expression { get; }

    /// <summary>The provider that executes this query.</summary>
    public IQueryProvider Provider { get; }

    /// <summary>Executes the query through the provider and enumerates the typed result.</summary>
    public IEnumerator<TTable> GetEnumerator()
    {
        return this.Provider.Execute<IEnumerable<TTable>>(this.Expression).GetEnumerator();
    }

    /// <summary>Executes the query through the provider and enumerates the untyped result.</summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        return this.Provider.Execute<IEnumerable>(this.Expression).GetEnumerator();
    }
}
}
Edit
Having updated my unit test to actually store the IQueryable from the .Where, I'm still not seeing my VisitMethodCall invoked.
// Updated test from the question: the value returned by Where is now what gets enumerated.
[Fact]
public void DatabaseModeler_provides_table_modeler()
{
// q is whatever the resolved Where overload returns; its static type is inferred by var.
var q = new LightmapQuery<AspNetRoles>(new SqliteProvider2())
.Where(role => role.Name == "Admin");
var result = q.ToList();
}
After some trial and error and a bit of help in IRC, I was able to identify the problem and resolve it. The issue was that my unit test project, which was a .Net Core project, did not have a reference to the System.Linq.Queryable assembly.
In my unit test
// Diagnostic version of the test: the result of Where is kept in q2 so its runtime type can be inspected.
[Fact]
public void DatabaseModeler_provides_table_modeler()
{
var q = new LightmapQuery<AspNetRoles>(new SqliteProvider2());
// var is deliberate here: the inferred type shows which Where overload was bound.
var q2 = q.Where(role => role.Name == "Admin");
var result = q2.ToList();
}
The Type of q2 is "WhereEnumerableIterator'1" which made us realize that it wasn't returning an IQueryable.
After adding the above reference to the project.json, the Type of q2 turned into "LightmapQuery'1" like I was expecting. With this, my VisitMethodCall method gets hit without issue.
The issue is trivial - you forgot to assign the result of applying Where:
q.Where(role => role.Name == "Admin");
use something like this instead
var q = new LightmapQuery<AspNetRoles>(new SqliteProvider2())
.Where(role => role.Name == "Admin");
var result = q.ToList();
Related
I want to remove any cast expression in an expression tree. We can assume that the cast is redundant.
For example, both these expressions:
IFoo t => ((t as Foo).Bar as IExtraBar).Baz;
IFoo t => ((IExtraBar)(t as Foo).Bar).Baz;
become this:
IFoo t => t.Bar.Baz
How can you accomplish this?
Sample code
The sample below illustrates a pretty simple scenario. However, my code fails with an exception:
Unhandled Exception: System.ArgumentException: Property 'IBar Bar' is not defined for type 'ExpressionTest.Program+IFoo'
using System;
using System.Linq.Expressions;
namespace ExpressionTest
{
// Repro from the question: sample interfaces/classes plus a first attempt at a cast-removing visitor.
class Program
{
public interface IBar
{
string Baz { get; }
}
public interface IExtraBar : IBar
{
}
public interface IFoo
{
IBar Bar { get; }
}
public class Foo : IFoo
{
public IBar Bar { get; }
}
static void Main(string[] args)
{
Expression<Func<IFoo, string>> expr = t => ((t as Foo).Bar as IExtraBar).Baz;
Expression<Func<IFoo, string>> expr2 = t => ((IExtraBar)(t as Foo).Bar).Baz;
// Wanted: IFoo t => t.Bar.Baz
var visitor = new CastRemoverVisitor();
// The expression returned by Visit is discarded; only the visitor's Expression property is printed.
visitor.Visit(expr);
Console.WriteLine(visitor.Expression.ToString());
}
// Visitor that replaces every unary node with its (visited) operand.
public class CastRemoverVisitor : ExpressionVisitor
{
// Holds the first node passed to Visit, and is overwritten with the operand of each unary node visited.
public Expression Expression { get; private set; }
public override Expression Visit(Expression node)
{
// Only assigns when Expression is still null, i.e. on the first call.
Expression ??= node;
return base.Visit(node);
}
protected override Expression VisitUnary(UnaryExpression node)
{
// Applies to every UnaryExpression, not only casts: the node type is not checked.
Expression = node.Operand;
return Visit(node.Operand);
}
}
}
}
Final solution
The accepted answer pinpoints the use of Expression.MakeMemberAccess and some interface tricks. We can "improve" the code a bit, to use an interfaced PropertyInfo instead of going through the interfaced getter. I ended up with the following:
/// <summary>
/// Strips <c>as</c>/cast nodes from an expression tree. Member accesses made through a cast are
/// re-targeted at the matching interface property (when one is found) on the uncast operand.
/// </summary>
public class CastRemoverVisitor : ExpressionVisitor
{
    /// <summary>Replaces a cast node with its visited operand; other unary nodes are handled by the base class.</summary>
    protected override Expression VisitUnary(UnaryExpression node)
    {
        if (node.IsCastExpression())
        {
            return Visit(node.Operand);
        }

        return base.VisitUnary(node);
    }

    /// <summary>
    /// For <c>(cast operand).Member</c>, rebuilds the access as <c>operand.InterfaceMember</c> when
    /// <c>ToInterfacedProperty</c> finds a counterpart; otherwise defers to the base class.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    {
        if (!(node.Expression is UnaryExpression cast) || !cast.IsCastExpression())
        {
            return base.VisitMember(node);
        }

        var interfaceMember = node.Member.ToInterfacedProperty();
        if (interfaceMember == null)
        {
            return base.VisitMember(node);
        }

        var uncastAccess = Expression.MakeMemberAccess(cast.Operand, interfaceMember);
        return base.Visit(uncastAccess);
    }
}
// And some useful extension methods...
/// <summary>Reflection helpers for <see cref="MemberInfo"/>.</summary>
public static class MemberInfoExtensions
{
    /// <summary>
    /// Looks through the interfaces of the member's declaring type for a property with the same
    /// name and returns the first one found, or null when no interface declares such a property.
    /// </summary>
    public static MemberInfo ToInterfacedProperty(this MemberInfo member)
    {
        foreach (var interfaceType in member.DeclaringType!.GetInterfaces())
        {
            var candidate = interfaceType.GetProperty(member.Name);
            if (candidate != null)
            {
                return candidate;
            }
        }

        return null;
    }
}
/// <summary>Helpers for classifying expression nodes.</summary>
public static class ExpressionExtensions
{
    /// <summary>Returns true when the node is a <c>TypeAs</c> or <c>Convert</c> node.</summary>
    public static bool IsCastExpression(this Expression expression)
    {
        switch (expression.NodeType)
        {
            case ExpressionType.TypeAs:
            case ExpressionType.Convert:
                return true;
            default:
                return false;
        }
    }
}
And then we use it like this:
var visitor = new CastRemoverVisitor();
var cleanExpr = visitor.Visit(expr);
Console.WriteLine(cleanExpr.ToString());
First off, nice repro. Thank you.
The problem is that the property access (t as Foo).Bar is calling the getter for Foo.Bar, and not the getter for IFoo.Bar (yes, those are different things with different MethodInfos).
You can see this by overriding VisitMember and see the MethodInfo being passed.
However, an approach like this seems to work. We have to unwrap things at the point of the member access, since we can only proceed if we can find an equivalent member to access on the uncasted type:
/// <summary>
/// Removes redundant casts in front of property accesses when the uncast operand is an interface
/// that exposes an equivalent property, e.g. <c>((Foo)t).Bar</c> with <c>t : IFoo</c> becomes <c>t.Bar</c>.
/// </summary>
public class CastRemoverVisitor : ExpressionVisitor
{
    /// <summary>
    /// Rewrites <c>(cast operand).Property</c> to <c>operand.Property</c> when an equivalent property
    /// can be located on the operand's interface type; otherwise leaves the node to the base visitor.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is UnaryExpression { NodeType: ExpressionType.TypeAs or ExpressionType.Convert, Operand: var operand } &&
            node.Member is PropertyInfo propertyInfo &&
            operand.Type.IsInterface)
        {
            // Is this just inheriting a type from a base interface?
            // Get rid of the cast, and just call the property on the uncasted member
            if (propertyInfo.DeclaringType == operand.Type)
            {
                return base.Visit(Expression.MakeMemberAccess(operand, propertyInfo));
            }

            // Is node.Expression a concrete type, which implements this interface method?
            var methodInfo = GetInterfaceMethodInfo(operand.Type, node.Expression.Type, propertyInfo.GetMethod);
            if (methodInfo != null)
            {
                // Expression.Property accepts an accessor method and yields a MemberExpression, so the
                // result prints as "t.Bar" instead of "t.get_Bar()".
                return base.Visit(Expression.Property(operand, methodInfo));
            }
        }

        return base.VisitMember(node);
    }

    /// <summary>
    /// Finds the method of <paramref name="interfaceType"/> that <paramref name="implementationMethodInfo"/>
    /// implements on <paramref name="implementationType"/>, or null when there is none.
    /// </summary>
    private static MethodInfo GetInterfaceMethodInfo(Type interfaceType, Type implementationType, MethodInfo implementationMethodInfo)
    {
        if (!implementationType.IsClass || implementationMethodInfo == null)
        {
            return null;
        }

        // GetInterfaceMap throws ArgumentException when the class does not implement the interface
        // (a cast from an interface to an unrelated, non-sealed class is legal C#).
        if (!interfaceType.IsAssignableFrom(implementationType))
        {
            return null;
        }

        var map = implementationType.GetInterfaceMap(interfaceType);
        for (int i = 0; i < map.InterfaceMethods.Length; i++)
        {
            if (map.TargetMethods[i] == implementationMethodInfo)
            {
                return map.InterfaceMethods[i];
            }
        }

        return null;
    }
}
I'm sure there are cases which will break this (fields come to mind, and I know GetInterfaceMap doesn't play well with generics in some situations), but it's a starting point.
The line price = co?.price ?? 0, in the following code gives me the above error, but if I remove ? from co.? it works fine.
I was trying to follow this MSDN example where they are using ? on line select new { person.FirstName, PetName = subpet?.Name ?? String.Empty }; So, it seems I need to understand when to use ? with ?? and when not to.
Error:
an expression tree lambda may not contain a null propagating operator
// View model projected by the query in the question; further members are elided with "....".
public class CustomerOrdersModelView
{
public string CustomerID { get; set; }
public int FY { get; set; }
// Nullable, which is why the query coalesces it with "?? 0".
public float? price { get; set; }
....
....
}
// Action from the question: group-joins Customers to Orders and projects into CustomerOrdersModelView.
// Parts of the method are elided with "....".
public async Task<IActionResult> ProductAnnualReport(string rpt)
{
var qry = from c in _context.Customers
join ord in _context.Orders
on c.CustomerID equals ord.CustomerID into co
from m in co.DefaultIfEmpty()
select new CustomerOrdersModelView
{
CustomerID = c.CustomerID,
FY = c.FY,
// This is the line reported in the question as failing with
// "an expression tree lambda may not contain a null propagating operator".
price = co?.price ?? 0,
....
....
};
....
....
}
The example you were quoting from uses LINQ to Objects, where the implicit lambda expressions in the query are converted into delegates... whereas you're using EF or similar, with IQueryable<T> queries, where the lambda expressions are converted into expression trees. Expression trees don't support the null conditional operator (or tuples).
Just do it the old way:
price = co == null ? 0 : (co.price ?? 0)
(I believe the null-coalescing operator is fine in an expression tree.)
The code you link to uses List<T>. List<T> implements IEnumerable<T> but not IQueryable<T>. In that case, the projection is executed in memory and ?. works.
You're using some IQueryable<T>, which works very differently. For IQueryable<T>, a representation of the projection is created, and your LINQ provider decides what to do with it at runtime. For backwards compatibility reasons, ?. cannot be used here.
Depending on your LINQ provider, you may be able to use plain . and still not get any NullReferenceException.
Jon Skeet's answer was right, in my case I was using DateTime for my Entity class.
When I tried to use like
(a.DateProperty == null ? default : a.DateProperty.Date)
I had the error
Property 'System.DateTime Date' is not defined for type 'System.Nullable`1[System.DateTime]' (Parameter 'property')
So I needed to change DateTime? for my entity class and
(a.DateProperty == null ? default : a.DateProperty.Value.Date)
While expression trees do not support the C# 6.0 null-propagating operator, what we can do is create a visitor that modifies the expression tree for safe null propagation, just like the operator does!
Here is mine:
// Expression visitor that rewrites member accesses and instance method calls so that the target
// is evaluated once into a temporary, compared with null, and the access is only performed when
// the target is not null. The rewritten access yields a nullable result.
public class NullPropagationVisitor : ExpressionVisitor
{
// When true, the target (instance) expression is itself visited before being wrapped,
// so chained accesses are guarded at every level.
private readonly bool _recursive;
public NullPropagationVisitor(bool recursive)
{
_recursive = recursive;
}
// Unary nodes whose operand is a member access, method call or conditional are replaced by the
// rewritten operand; the unary node itself is not rebuilt in those cases.
// NOTE(review): the unary node type is not checked, so this applies to any unary operator,
// not only conversions - confirm that is intended.
protected override Expression VisitUnary(UnaryExpression propertyAccess)
{
if (propertyAccess.Operand is MemberExpression mem)
return VisitMember(mem);
if (propertyAccess.Operand is MethodCallExpression met)
return VisitMethodCall(met);
if (propertyAccess.Operand is ConditionalExpression cond)
// The test is kept as-is; both branches are visited and made nullable.
return Expression.Condition(
test: cond.Test,
ifTrue: MakeNullable(Visit(cond.IfTrue)),
ifFalse: MakeNullable(Visit(cond.IfFalse)));
return base.VisitUnary(propertyAccess);
}
// Every member access is wrapped by Common, using the accessed object as the guarded target.
// NOTE(review): MemberExpression.Expression is null for static members; no check is made here.
protected override Expression VisitMember(MemberExpression propertyAccess)
{
return Common(propertyAccess.Expression, propertyAccess);
}
// Calls without an Object (static calls) are left to the base visitor; instance calls are wrapped.
protected override Expression VisitMethodCall(MethodCallExpression propertyAccess)
{
if (propertyAccess.Object == null)
return base.VisitMethodCall(propertyAccess);
return Common(propertyAccess.Object, propertyAccess);
}
// Builds: { caller = <instance>; caller == null ? null : <access with instance replaced by caller> }
// instance:       the expression the member/method is invoked on.
// propertyAccess: the original member access or method call node.
private BlockExpression Common(Expression instance, Expression propertyAccess)
{
var safe = _recursive ? base.Visit(instance) : instance;
// Temporary holding the evaluated target so it is only evaluated once.
var caller = Expression.Variable(safe.Type, "caller");
var assign = Expression.Assign(caller, safe);
// Re-creates the access with the original instance node swapped for the temporary
// (passed through RemoveNullable unless the instance is itself a Nullable<T>),
// then makes the result type nullable.
var acess = MakeNullable(new ExpressionReplacer(instance,
IsNullableStruct(instance) ? caller : RemoveNullable(caller)).Visit(propertyAccess));
var ternary = Expression.Condition(
test: Expression.Equal(caller, Expression.Constant(null)),
ifTrue: Expression.Constant(null, acess.Type),
ifFalse: acess);
return Expression.Block(
type: acess.Type,
variables: new[]
{
caller,
},
expressions: new Expression[]
{
assign,
ternary,
});
}
// Returns ex unchanged when its type can already hold null; otherwise converts it to Nullable<T>.
private static Expression MakeNullable(Expression ex)
{
if (IsNullable(ex))
return ex;
return Expression.Convert(ex, typeof(Nullable<>).MakeGenericType(ex.Type));
}
// True for non-value types and for Nullable<T>.
private static bool IsNullable(Expression ex)
{
return !ex.Type.IsValueType || (Nullable.GetUnderlyingType(ex.Type) != null);
}
// True only for Nullable<T>.
private static bool IsNullableStruct(Expression ex)
{
return ex.Type.IsValueType && (Nullable.GetUnderlyingType(ex.Type) != null);
}
// Converts a Nullable<T> expression to T; other expressions are returned unchanged.
private static Expression RemoveNullable(Expression ex)
{
if (IsNullableStruct(ex))
return Expression.Convert(ex, ex.Type.GenericTypeArguments[0]);
return ex;
}
// Visitor that substitutes one specific node (_oldEx, matched with ==) with _newEx.
private class ExpressionReplacer : ExpressionVisitor
{
private readonly Expression _oldEx;
private readonly Expression _newEx;
internal ExpressionReplacer(Expression oldEx, Expression newEx)
{
_oldEx = oldEx;
_newEx = newEx;
}
public override Expression Visit(Expression node)
{
if (node == _oldEx)
return _newEx;
return base.Visit(node);
}
}
}
It passes on the following tests:
// Identity helper used by Test1 to put a static method call into the expression.
private static string Foo(string s) => s;
// Runs three checks of NullPropagationVisitor (recursive mode) using Debug.Assert.
static void Main(string[] _)
{
var visitor = new NullPropagationVisitor(recursive: true);
Test1();
Test2();
Test3();
// Visits a whole lambda containing a conditional, a static call and chained accesses.
void Test1()
{
Expression<Func<string, char?>> f = s => s == "foo" ? 'X' : Foo(s).Length.ToString()[0];
var fBody = (Expression<Func<string, char?>>)visitor.Visit(f);
var fFunc = fBody.Compile();
Debug.Assert(fFunc(null) == null);
Debug.Assert(fFunc("bar") == '3');
Debug.Assert(fFunc("foo") == 'X');
}
// Visits only the body; the rewritten body is wrapped in a lambda returning int? instead of int.
void Test2()
{
Expression<Func<string, int>> y = s => s.Length;
var yBody = visitor.Visit(y.Body);
var yFunc = Expression.Lambda<Func<string, int?>>(
body: yBody,
parameters: y.Parameters)
.Compile();
Debug.Assert(yFunc(null) == null);
Debug.Assert(yFunc("bar") == 3);
}
// Same as Test2 but starting from a Nullable<char> parameter.
void Test3()
{
Expression<Func<char?, string>> y = s => s.Value.ToString()[0].ToString();
var yBody = visitor.Visit(y.Body);
var yFunc = Expression.Lambda<Func<char?, string>>(
body: yBody,
parameters: y.Parameters)
.Compile();
Debug.Assert(yFunc(null) == null);
Debug.Assert(yFunc('A') == "A");
}
}
I'm using XUNIT to test in a dot net core application.
I need to test a service that is internally making an async query on a DbSet in my datacontext.
I've seen here that mocking that DbSet asynchronously is possible.
The problem I'm having is that the IDbAsyncQueryProvider does not seem to be available in EntityframeworkCore, which I'm using.
Am I incorrect here? Has anyone else got this working?
(Been a long day, hopefully I'm just missing something simple)
EDIT
After asking on GitHub, I got pointed to this class:
https://github.com/aspnet/EntityFramework/blob/dev/src/Microsoft.EntityFrameworkCore/Query/Internal/IAsyncQueryProvider.cs
This is what I've gotten to so far in trying to implement this:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore.Query.Internal;
namespace EFCoreTestQueryProvider
{
/// <summary>
/// Test double for EF Core's <c>IAsyncQueryProvider</c>: query creation returns
/// <c>TestDbAsyncEnumerable</c> instances and execution is forwarded to the wrapped provider.
/// </summary>
internal class TestAsyncQueryProvider<TEntity> : IAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    internal TestAsyncQueryProvider(IQueryProvider inner)
    {
        _inner = inner;
    }

    // The members below must be public: members without an access modifier are private and
    // therefore do not implement the interface.

    /// <summary>Creates a non-generic query over <paramref name="expression"/>.</summary>
    public IQueryable CreateQuery(Expression expression)
    {
        return new TestDbAsyncEnumerable<TEntity>(expression);
    }

    /// <summary>Creates a typed query over <paramref name="expression"/>.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new TestDbAsyncEnumerable<TElement>(expression);
    }

    /// <summary>Executes synchronously through the wrapped provider.</summary>
    public object Execute(Expression expression)
    {
        return _inner.Execute(expression);
    }

    /// <summary>Executes synchronously through the wrapped provider.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        return _inner.Execute<TResult>(expression);
    }

    /// <summary>
    /// Returns an async sequence over <paramref name="expression"/>.
    /// </summary>
    public IAsyncEnumerable<TResult> ExecuteAsync<TResult>(Expression expression)
    {
        // Here TResult is the element type and the expression describes the whole sequence, so it
        // cannot be executed as a single TResult; wrap the expression in an async-enumerable query.
        // NOTE(review): relies on TestDbAsyncEnumerable<T> implementing IAsyncEnumerable<T> - confirm.
        return new TestDbAsyncEnumerable<TResult>(expression);
    }

    /// <summary>Executes synchronously and returns the result as a completed task.</summary>
    Task<TResult> IAsyncQueryProvider.ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        return Task.FromResult(Execute<TResult>(expression));
    }
}
// The asker's work-in-progress async enumerable wrapper around EnumerableQuery<T>.
// The question states this does not work yet; see the notes on individual members.
internal class TestDbAsyncEnumerable<T> : EnumerableQuery<T>, System.Collections.Generic.IAsyncEnumerable<T>, IQueryable<T>
{
public TestDbAsyncEnumerable(IEnumerable<T> enumerable)
: base(enumerable)
{ }
public TestDbAsyncEnumerable(Expression expression)
: base(expression)
{ }
// NOTE(review): declared to return IAsyncEnumerator<T> but the expression ends in
// ToAsyncEnumerable(), which presumably produces an enumerable rather than an enumerator.
public IAsyncEnumerator<T> GetAsyncEnumerator()
{
return new TestDbAsyncEnumerable<T>(this.AsEnumerable()).ToAsyncEnumerable();
}
// NOTE(review): explicit implementation of IDbAsyncEnumerable, which is not in this class's
// base list above.
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return GetAsyncEnumerator();
}
// Not implemented yet: always throws.
IAsyncEnumerator<T> IAsyncEnumerable<T>.GetEnumerator()
{
throw new NotImplementedException();
}
// Creates a new TestAsyncQueryProvider wrapping this query on every access.
IQueryProvider IQueryable.Provider
{
get { return new TestAsyncQueryProvider<T>(this); }
}
}
}
I've now tried to implement this and have run into some more issues, specifically around these two methods:
public IAsyncEnumerator<T> GetAsyncEnumerator()
{
return new TestDbAsyncEnumerable<T>(this.AsEnumerable()).ToAsyncEnumerable();
}
IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
{
return GetAsyncEnumerator();
}
I'm hoping that somebody could point me in the right direction as to what I'm doing wrong.
I finally got this to work. They slightly changed the interfaces in EntityFrameworkCore from IDbAsyncEnumerable to IAsyncEnumerable so the following code worked for me:
/// <summary>
/// In-memory query that can also be enumerated through <c>IAsyncEnumerable&lt;T&gt;</c> by
/// wrapping its synchronous enumerator.
/// </summary>
public class AsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
{
    public AsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    /// <summary>Returns an async enumerator over this query's synchronous enumerator.</summary>
    public IAsyncEnumerator<T> GetEnumerator()
    {
        IEnumerator<T> inner = this.AsEnumerable().GetEnumerator();
        return new AsyncEnumerator<T>(inner);
    }
}
/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to the async enumerator interface; every
/// <c>MoveNext</c> completes synchronously.
/// </summary>
public class AsyncEnumerator<T> : IAsyncEnumerator<T>
{
    private readonly IEnumerator<T> enumerator;

    /// <exception cref="ArgumentNullException"><paramref name="enumerator"/> is null.</exception>
    public AsyncEnumerator(IEnumerator<T> enumerator) =>
        this.enumerator = enumerator ?? throw new ArgumentNullException(nameof(enumerator));

    /// <summary>The element at the wrapped enumerator's current position.</summary>
    public T Current => enumerator.Current;

    /// <summary>Disposes the wrapped enumerator so iterator <c>finally</c> blocks and resources are released.</summary>
    public void Dispose() => enumerator.Dispose();

    /// <summary>Advances the wrapped enumerator and returns the outcome as a completed task.</summary>
    public Task<bool> MoveNext(CancellationToken cancellationToken) =>
        Task.FromResult(enumerator.MoveNext());
}
// Verifies that ToListAsync works against a mocked DbSet backed by an in-memory list.
[Fact]
public async Task TestEFCore()
{
    var data =
        new List<Entity>()
        {
            new Entity(),
            new Entity(),
            new Entity()
        }.AsQueryable();

    var mockDbSet = new Mock<DbSet<Entity>>();

    // Enumerators are created lazily per call: returning one pre-built instance would hand the
    // same (possibly already exhausted) enumerator to every enumeration.
    mockDbSet.As<IAsyncEnumerable<Entity>>()
        .Setup(d => d.GetEnumerator())
        .Returns(() => new AsyncEnumerator<Entity>(data.GetEnumerator()));
    mockDbSet.As<IQueryable<Entity>>().Setup(m => m.Provider).Returns(data.Provider);
    mockDbSet.As<IQueryable<Entity>>().Setup(m => m.Expression).Returns(data.Expression);
    mockDbSet.As<IQueryable<Entity>>().Setup(m => m.ElementType).Returns(data.ElementType);
    mockDbSet.As<IQueryable<Entity>>().Setup(m => m.GetEnumerator()).Returns(() => data.GetEnumerator());

    var mockCtx = new Mock<SomeDbContext>();
    mockCtx.SetupGet(c => c.Entities).Returns(mockDbSet.Object);

    var entities = await mockCtx.Object.Entities.ToListAsync();

    Assert.NotNull(entities);
    // ToListAsync yields a List<T>; use its Count property rather than LINQ Count().
    Assert.Equal(3, entities.Count);
}
You might be able to clean up those test implementations of the AsyncEnumerable and AsyncEnumerator even more. I didn't try, I just got it to work.
Remember your DbSet on your DbContext needs to be marked as virtual or else you will need to implement some interface wrapper over the DbContext to make this work properly.
I got help from Carson's answer, but I had to alter his code a bit to make it work with EntityFramework Core 6.4.4 and Moq.
Here is the altered code:
/// <summary>
/// In-memory query that also implements <c>IAsyncEnumerable&lt;T&gt;</c> by wrapping its
/// synchronous enumerator.
/// </summary>
internal class AsyncEnumerable<T> : EnumerableQuery<T>, IAsyncEnumerable<T>, IQueryable<T>
{
    public AsyncEnumerable(Expression expression)
        : base(expression)
    {
    }

    /// <summary>Returns an async enumerator over this query; the cancellation token is not used.</summary>
    public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
    {
        IEnumerator<T> inner = this.AsEnumerable().GetEnumerator();
        return new AsyncEnumerator<T>(inner);
    }
}
/// <summary>
/// Adapts a synchronous <see cref="IEnumerator{T}"/> to <see cref="IAsyncEnumerator{T}"/>; every
/// <see cref="MoveNextAsync"/> completes synchronously. Disposing (sync or async) disposes the
/// wrapped enumerator exactly once.
/// </summary>
internal class AsyncEnumerator<T> : IAsyncEnumerator<T>, IAsyncDisposable, IDisposable
{
    private readonly IEnumerator<T> enumerator;

    // Guards against disposing the wrapped enumerator twice (Dispose after DisposeAsync, etc.).
    private bool disposed;

    /// <exception cref="ArgumentNullException"><paramref name="enumerator"/> is null.</exception>
    public AsyncEnumerator(IEnumerator<T> enumerator) =>
        this.enumerator = enumerator ?? throw new ArgumentNullException(nameof(enumerator));

    /// <summary>The element at the wrapped enumerator's current position.</summary>
    public T Current => enumerator.Current;

    /// <summary>Advances the wrapped enumerator and returns the outcome as a completed task.</summary>
    public ValueTask<bool> MoveNextAsync() =>
        new ValueTask<bool>(enumerator.MoveNext());

    /// <summary>Synchronously releases the wrapped enumerator.</summary>
    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }

    /// <summary>Asynchronously releases the wrapped enumerator.</summary>
    public async ValueTask DisposeAsync()
    {
        await DisposeAsyncCore().ConfigureAwait(false);
        // Managed state was released by DisposeAsyncCore; pass false as in the standard pattern.
        Dispose(disposing: false);
        GC.SuppressFinalize(this);
    }

    /// <summary>Releases the wrapped enumerator when <paramref name="disposing"/> is true.</summary>
    protected virtual void Dispose(bool disposing)
    {
        if (disposing)
        {
            ReleaseEnumerator();
        }
    }

    /// <summary>Async counterpart of <see cref="Dispose(bool)"/>; the wrapped enumerator only supports sync disposal.</summary>
    protected virtual ValueTask DisposeAsyncCore()
    {
        ReleaseEnumerator();
        return default;
    }

    // The only resource this type owns is the wrapped enumerator.
    private void ReleaseEnumerator()
    {
        if (disposed)
        {
            return;
        }

        disposed = true;
        enumerator.Dispose();
    }
}
/// <summary>
/// Async query provider test double: creates <c>AsyncEnumerable</c> queries and forwards all
/// execution (sync and async) to the wrapped provider.
/// </summary>
internal class TestAsyncQueryProvider<TEntity> : IDbAsyncQueryProvider
{
    private readonly IQueryProvider _inner;

    internal TestAsyncQueryProvider(IQueryProvider inner) => _inner = inner;

    /// <summary>Creates a query of <typeparamref name="TEntity"/> over the expression.</summary>
    public IQueryable CreateQuery(Expression expression) =>
        new AsyncEnumerable<TEntity>(expression);

    /// <summary>Creates a query of <typeparamref name="TElement"/> over the expression.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression) =>
        new AsyncEnumerable<TElement>(expression);

    /// <summary>Executes through the wrapped provider.</summary>
    public object Execute(Expression expression) => _inner.Execute(expression);

    /// <summary>Executes through the wrapped provider.</summary>
    public TResult Execute<TResult>(Expression expression) => _inner.Execute<TResult>(expression);

    /// <summary>Executes synchronously and wraps the result in a completed task; the token is not used.</summary>
    public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
    {
        object outcome = Execute(expression);
        return Task.FromResult(outcome);
    }

    /// <summary>Executes synchronously and wraps the result in a completed task; the token is not used.</summary>
    public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
    {
        TResult outcome = Execute<TResult>(expression);
        return Task.FromResult(outcome);
    }
}
A helper class:
/// <summary>Factory for Moq-based <c>DbSet</c> mocks that support async LINQ operators.</summary>
public static class MockDbSet
{
    /// <summary>
    /// Builds a mock <c>DbSet&lt;TEntity&gt;</c> backed by <paramref name="data"/>. Items passed to
    /// <c>Add</c> on the mock are appended to <paramref name="data"/>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public static Mock<DbSet<TEntity>> BuildAsync<TEntity>(List<TEntity> data) where TEntity : class
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        var queryable = data.AsQueryable();
        var mockSet = new Mock<DbSet<TEntity>>();

        // Create a fresh enumerator per call (as already done for the sync GetEnumerator below):
        // a single shared instance would be exhausted after the first enumeration and would be
        // invalidated by any later Add to the backing list.
        mockSet.As<IAsyncEnumerable<TEntity>>()
            .Setup(d => d.GetAsyncEnumerator(It.IsAny<CancellationToken>()))
            .Returns(() => new AsyncEnumerator<TEntity>(queryable.GetEnumerator()));
        mockSet.As<IQueryable<TEntity>>()
            .Setup(m => m.Provider)
            .Returns(new TestAsyncQueryProvider<TEntity>(queryable.Provider));
        mockSet.As<IQueryable<TEntity>>().Setup(m => m.Expression).Returns(queryable.Expression);
        mockSet.As<IQueryable<TEntity>>().Setup(m => m.ElementType).Returns(queryable.ElementType);
        mockSet.As<IQueryable<TEntity>>().Setup(m => m.GetEnumerator()).Returns(() => queryable.GetEnumerator());
        mockSet.Setup(m => m.Add(It.IsAny<TEntity>())).Callback<TEntity>(data.Add);
        return mockSet;
    }
}
Mocking async unit test (MSTestV2):
// Checks that the repository returns a non-null result when the context exposes a mocked set.
[TestMethod]
public async Task GetData_Should_Not_Return_Null()
{
    // Arrange
    var entities = new List<Entity> { new Entity() };
    var dbSet = MockDbSet.BuildAsync(entities).Object;
    _mockContext.Setup(m => m.Entitys).Returns(dbSet);

    // Act
    var actual = await _repository.GetDataAsync();

    // Assert
    Assert.IsNotNull(actual);
}
See here to mock DbContext with async queryable support: https://stackoverflow.com/a/71076807/4905704
Like this:
// Usage sketch for MockDbContextAsynced: seed a set with test data, then use async operators on it.
// The calls below are alternatives ("or"), shown without awaiting or asserting on their results.
[Fact]
public void Test()
{
var testData = new List<MyEntity>
{
new MyEntity() { Id = Guid.NewGuid() },
new MyEntity() { Id = Guid.NewGuid() },
new MyEntity() { Id = Guid.NewGuid() },
};
var mockDbContext = new MockDbContextAsynced<MyDbContext>();
mockDbContext.AddDbSetData<MyEntity>(testData.AsQueryable());
mockDbContext.MyEntities.ToArrayAsync();
// or
mockDbContext.MyEntities.SingleAsync();
// or etc.
// To inject MyDbContext as type parameter with mocked data
var mockService = new SomeService(mockDbContext.Object);
}
I have an Expression in the following form:
Expression<Func<T, bool>> predicate = t => t.Value == "SomeValue";
Is it possible to create a 'partially applied' version of this expression:
Expression<Func<bool>> predicate = () => t.Value == "SomeValue";
NB This expression is never actually compiled or invoked, it is merely inspected to generate some SQL.
This could be easily achieved by writing a custom ExpressionVisitor and replacing the parameter with a constant expression that you have captured in a closure:
// Sample entity used by the predicate in the example below.
public class Foo
{
public string Value { get; set; }
}
/// <summary>
/// Replaces parameters of type <typeparamref name="T"/> in an expression with a constant holding
/// a captured instance, turning e.g. the body of <c>t =&gt; t.Value == "x"</c> into a
/// parameterless expression.
/// </summary>
/// <remarks>
/// Intended to be applied to a lambda's body. Parameters of other types (for example those of
/// nested lambdas such as <c>c</c> in <c>s.Any(c =&gt; ...)</c>) are left untouched. A nested lambda
/// that declares its own parameter of type <typeparamref name="T"/> is not supported.
/// </remarks>
public class ReplaceVisitor<T> : ExpressionVisitor
{
    private readonly T _instance;

    public ReplaceVisitor(T instance)
    {
        _instance = instance;
    }

    /// <summary>Substitutes a typed constant for parameters of type <typeparamref name="T"/>.</summary>
    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (node.Type != typeof(T))
        {
            // Not the parameter being bound; replacing it would corrupt nested lambdas.
            return base.VisitParameter(node);
        }

        // Pass the type explicitly: Expression.Constant(value) uses the runtime type (or object for
        // null), which can make the rebuilt member access invalid.
        return Expression.Constant(_instance, typeof(T));
    }
}
// Demo: binds a Foo instance into a Func<Foo, bool> predicate and evaluates the result.
class Program
{
    static void Main()
    {
        var foo = new Foo { Value = "SomeValue" };
        Expression<Func<Foo, bool>> predicate = t => t.Value == "SomeValue";

        Expression<Func<bool>> result = Convert(predicate, foo);
        Func<bool> compiled = result.Compile();
        Console.WriteLine(compiled());
    }

    // Rewrites the predicate's body with the parameter replaced by `instance` and wraps it in a
    // parameterless lambda.
    static Expression<Func<bool>> Convert<T>(Expression<Func<T, bool>> expression, T instance)
    {
        var visitor = new ReplaceVisitor<T>(instance);
        Expression boundBody = visitor.Visit(expression.Body);
        return Expression.Lambda<Func<bool>>(boundBody);
    }
}
I think this should work:
Expression predicate2 = Expression.Invoke(predicate, Expression.Constant(new T() { Value = "SomeValue"}));
Expression<Func<bool>> predicate3 = Expression.Lambda<Func<bool>>(predicate2);
Don't know if this is easily parseable to generate SQL however (it works when compiled - I tried).
I want to rewrite certain parts of the LINQ expression just before execution. And I'm having problems injecting my rewriter in the correct place (at all actually).
Looking at the Entity Framework source (in reflector) it in the end comes down to the IQueryProvider.Execute which in EF is coupled to the expression by the ObjectContext offering the internal IQueryProvider Provider { get; } property.
So I created a wrapper class (implementing IQueryProvider) to do the Expression rewriting when Execute gets called and then pass it to the original Provider.
Problem is, the field behind Provider is private ObjectQueryProvider _queryProvider;. This ObjectQueryProvider is an internal sealed class, meaning it's not possible to create a subclass offering the added rewriting.
So this approach got me to a dead end due to the very tightly coupled ObjectContext.
How to solve this problem? Am I looking in the wrong direction? Is there perhaps a way to inject myself around this ObjectQueryProvider?
Update: While the provided solutions all work when you're "wrapping" the ObjectContext using the Repository pattern, a solution which would allow for direct usage of the generated subclass from ObjectContext would be preferable. Hereby remaining compatible with the Dynamic Data scaffolding.
Based on the answer by Arthur, I've created a working wrapper.
The snippets provided provide a way to wrap each LINQ query with your own QueryProvider and IQueryable root. This would mean that you've got to have control over the initial query starting (as you'll have most of the time using any sort of pattern).
The problem with this method is that it's not transparent, a more ideal situation would be to inject something in the entities container at the constructor level.
I've created a compilable implementation, got it to work with Entity Framework, and added support for the ObjectQuery.Include method. The expression visitor class can be copied from MSDN.
/// <summary>
/// Queryable wrapper around an existing <see cref="IQueryable"/> source. All queries built on it
/// go through a <c>QueryTranslatorProvider&lt;T&gt;</c>, which rewrites the expression before
/// handing it to the source's provider.
/// </summary>
public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private readonly Expression expression;
    private readonly QueryTranslatorProvider<T> provider;

    /// <summary>Creates a root query: the expression is a constant referring to this instance.</summary>
    public QueryTranslator(IQueryable source)
    {
        provider = new QueryTranslatorProvider<T>(source);
        expression = Expression.Constant(this);
    }

    /// <summary>Creates a query over an existing expression tree.</summary>
    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null)
        {
            throw new ArgumentNullException("e");
        }

        expression = e;
        provider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>Translates and runs the query, enumerating the typed results.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        var results = (IEnumerable<T>)provider.ExecuteEnumerable(this.expression);
        return results.GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        var results = provider.ExecuteEnumerable(this.expression);
        return results.GetEnumerator();
    }

    /// <summary>
    /// Applies <c>ObjectQuery.Include</c> to the underlying source and wraps the result in a new
    /// root query. Only possible while the source is still an <c>ObjectQuery&lt;T&gt;</c>.
    /// </summary>
    public QueryTranslator<T> Include(String path)
    {
        var objectQuery = provider.source as ObjectQuery<T>;
        if (objectQuery == null)
        {
            throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
        }

        return new QueryTranslator<T>(objectQuery.Include(path));
    }

    public Type ElementType => typeof(T);

    public Expression Expression => expression;

    public IQueryProvider Provider => provider;
}
/// <summary>
/// Query provider that visits (rewrites) each expression and then delegates to the provider of
/// the wrapped source. The root <c>QueryTranslator&lt;T&gt;</c> constant is replaced by the
/// source's own expression during the visit.
/// </summary>
public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    internal IQueryable source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        this.source = source;
    }

    /// <summary>Wraps the expression in a typed <c>QueryTranslator</c> over the same source.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        IQueryable<TElement> query = new QueryTranslator<TElement>(source, expression);
        return query;
    }

    /// <summary>
    /// Wraps the expression in a <c>QueryTranslator</c> closed over the first generic argument of
    /// the expression's type.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Type elementType = expression.Type.GetGenericArguments().First();
        Type queryType = typeof(QueryTranslator<>).MakeGenericType(elementType);
        object[] constructorArgs = { source, expression };
        return (IQueryable)Activator.CreateInstance(queryType, constructorArgs);
    }

    /// <summary>Runs the non-generic <see cref="Execute(Expression)"/> and casts its result.</summary>
    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        object result = ((IQueryProvider)this).Execute(expression);
        return (TResult)result;
    }

    /// <summary>Rewrites the expression and executes it on the source's provider.</summary>
    public object Execute(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Expression translated = this.Visit(expression);
        return source.Provider.Execute(translated);
    }

    /// <summary>Rewrites the expression and creates a query for it on the source's provider.</summary>
    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Expression translated = this.Visit(expression);
        return source.Provider.CreateQuery(translated);
    }

    #region Visitors
    protected override Expression VisitConstant(ConstantExpression c)
    {
        if (c.Type != typeof(QueryTranslator<T>))
        {
            return base.VisitConstant(c);
        }

        // fix up the Expression tree to work with EF again
        return source.Expression;
    }
    #endregion
}
Example usage in your repository:
/// <summary>
/// Returns the users query wrapped in a <c>QueryTranslator</c>, with the
/// "Department" path passed to <c>Include</c>.
/// </summary>
public IQueryable<User> List()
{
    QueryTranslator<User> users = new QueryTranslator<User>(entities.Users);
    QueryTranslator<User> usersWithDepartment = users.Include("Department");
    return usersWithDepartment;
}
I have exactly the source code you'll need, but no idea how to attach a file.
Here are some snippets (snippets! I had to adapt this code, so it may not compile):
IQueryable:
/// <summary>
/// Queryable wrapper that pairs an expression tree with a
/// <c>QueryTranslatorProvider&lt;T&gt;</c> built over the wrapped source.
/// </summary>
public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private readonly Expression _expression;
    private readonly QueryTranslatorProvider<T> _provider;

    /// <summary>
    /// Creates a root query: its expression is a constant that refers to this instance.
    /// </summary>
    /// <param name="source">The query handed to the provider.</param>
    public QueryTranslator(IQueryable source)
    {
        _expression = Expression.Constant(this);
        _provider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>
    /// Creates a query over <paramref name="source"/> that carries an existing expression tree.
    /// </summary>
    /// <param name="source">The query handed to the provider.</param>
    /// <param name="e">The expression tree this query represents.</param>
    /// <exception cref="ArgumentNullException"><paramref name="e"/> is null.</exception>
    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null)
        {
            throw new ArgumentNullException("e");
        }

        _expression = e;
        _provider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>Gets <c>typeof(T)</c>.</summary>
    public Type ElementType
    {
        get { return typeof(T); }
    }

    /// <summary>Gets the expression tree held by this query.</summary>
    public Expression Expression
    {
        get { return _expression; }
    }

    /// <summary>Gets the provider created for this query.</summary>
    public IQueryProvider Provider
    {
        get { return _provider; }
    }

    /// <summary>
    /// Asks the provider to run this query's expression and enumerates the result as <c>T</c>.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
        IEnumerable results = _provider.ExecuteEnumerable(_expression);
        IEnumerable<T> typedResults = (IEnumerable<T>)results;
        return typedResults.GetEnumerator();
    }

    /// <summary>Non-generic enumeration of the same provider result.</summary>
    IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        IEnumerable results = _provider.ExecuteEnumerable(_expression);
        return results.GetEnumerator();
    }
}
IQueryProvider:
/// <summary>
/// Query provider that creates <c>QueryTranslator</c> queries over a wrapped source and,
/// on execution, visits the expression before passing it to the wrapped source's provider.
/// </summary>
public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
{
    IQueryable _source;

    /// <summary>Creates a provider that delegates execution to <paramref name="source"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null)
        {
            throw new ArgumentNullException("source");
        }

        _source = source;
    }

    #region IQueryProvider Members
    /// <summary>Wraps the expression in a new <c>QueryTranslator&lt;TElement&gt;</c> over the same source.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        QueryTranslator<TElement> query = new QueryTranslator<TElement>(_source, expression);
        return query as IQueryable<TElement>;
    }

    /// <summary>
    /// Non-generic variant: uses the first element type reported by <c>FindElementTypes</c>
    /// for the expression's type and constructs the matching <c>QueryTranslator&lt;&gt;</c>
    /// through reflection.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Type elementType = expression.Type.FindElementTypes().First();
        Type queryType = typeof(QueryTranslator<>).MakeGenericType(elementType);
        object[] constructorArgs = new object[] { _source, expression };
        return (IQueryable)Activator.CreateInstance(queryType, constructorArgs);
    }

    /// <summary>Runs the non-generic <c>Execute</c> and casts its result to <typeparamref name="TResult"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        IQueryProvider self = this;
        object untyped = self.Execute(expression);
        return (TResult)untyped;
    }

    /// <summary>Visits the expression and executes the result on the wrapped source's provider.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public object Execute(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Expression visited = this.Visit(expression);
        return _source.Provider.Execute(visited);
    }

    /// <summary>Visits the expression and creates a query for it on the wrapped source's provider.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException("expression");
        }

        Expression visited = this.Visit(expression);
        return _source.Provider.CreateQuery(visited);
    }
    #endregion

    #region Visits
    /// <summary>Returns method-call nodes unchanged, without visiting their children.</summary>
    protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
    {
        return m;
    }

    /// <summary>
    /// Rebuilds a unary node from its visited operand, using the result of
    /// <c>ToImplementationType()</c> on the node's type as the new type.
    /// </summary>
    protected override Expression VisitUnary(UnaryExpression u)
    {
        ExpressionType nodeType = u.NodeType;
        Expression visitedOperand = base.Visit(u.Operand);
        Type implementationType = u.Type.ToImplementationType();
        return Expression.MakeUnary(nodeType, visitedOperand, implementationType, u.Method);
    }
    #endregion
}
Usage (warning: adapted code! May not compile):
// Cache of one QueryTranslator per entity type, keyed by typeof(T).
private Dictionary<Type, object> _table = new Dictionary<Type, object>();

/// <summary>
/// Returns the cached <c>QueryTranslator&lt;T&gt;</c> for <typeparamref name="T"/>,
/// creating and caching it on first use.
/// </summary>
/// <returns>The wrapped query for <typeparamref name="T"/>.</returns>
public override IQueryable<T> GetObjectQuery<T>()
{
    // The cache key is the entity type itself.
    Type type = typeof(T);

    object query;
    if (!_table.TryGetValue(type, out query))
    {
        // NOTE: the entity set name is taken to be the CLR type name here.
        query = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + type.Name + "]"));
        _table[type] = query;
    }

    return (IQueryable<T>)query;
}
Expression Visitors/Translator:
http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx
http://msdn.microsoft.com/en-us/library/bb882521.aspx
EDIT: Added FindElementTypes(). Hopefully all Methods are present now.
/// <summary>
/// Collects the <c>IEnumerable</c> / <c>IEnumerable&lt;T&gt;</c> types reachable from
/// <paramref name="seqType"/> through its base types and interfaces.
/// </summary>
/// <param name="seqType">The type to inspect; may be null.</param>
/// <returns>
/// An empty sequence for null, <c>object</c> and <c>string</c>; only the non-generic
/// <c>IEnumerable</c> for arrays and for <c>IEnumerable</c> itself; otherwise the
/// base-type results followed by the interface results, without duplicates.
/// </returns>
public static IQueryable<Type> FindIEnumerables(this Type seqType)
{
    bool isExcluded = seqType == null || seqType == typeof(object) || seqType == typeof(string);
    if (isExcluded)
    {
        return (new Type[0]).AsQueryable();
    }

    if (seqType.IsArray || seqType == typeof(IEnumerable))
    {
        return new[] { typeof(IEnumerable) }.AsQueryable();
    }

    bool isGenericEnumerable = seqType.IsGenericType
        && seqType.GetGenericArguments().Length == 1
        && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>);
    if (isGenericEnumerable)
    {
        return new[] { seqType, typeof(IEnumerable) }.AsQueryable();
    }

    // Gather everything the implemented interfaces contribute before walking up the base chain.
    Type[] interfaces = seqType.GetInterfaces() ?? new Type[0];
    List<Type> fromInterfaces = new List<Type>();
    for (int i = 0; i < interfaces.Length; i++)
    {
        fromInterfaces.AddRange(FindIEnumerables(interfaces[i]));
    }

    IQueryable<Type> fromBase = FindIEnumerables(seqType.BaseType);
    return fromBase.Union(fromInterfaces);
}
/// <summary>
/// Maps every enumerable type found by <c>FindIEnumerables</c> to its element type:
/// the single generic argument for a generic type, <c>object</c> otherwise.
/// </summary>
/// <param name="seqType">The sequence type to inspect.</param>
/// <returns>One element type per enumerable type found, in the same order.</returns>
public static IQueryable<Type> FindElementTypes(this Type seqType)
{
    IQueryable<Type> enumerables = seqType.FindIEnumerables();
    return from enumerable in enumerables
           select enumerable.IsGenericType ? enumerable.GetGenericArguments().Single() : typeof(object);
}
Just wanted to add to Arthur's example.
As Arthur warned, there is a bug in his GetObjectQuery() method.
It creates the base query using typeof(T).Name as the name of the EntitySet,
but the EntitySet name is distinct from the type name.
If you are using EF 4 you should do this:
/// <summary>
/// Returns the cached <c>QueryTranslator&lt;T&gt;</c> for <typeparamref name="T"/>,
/// creating it over <c>_ctx.CreateObjectSet&lt;T&gt;()</c> on first use.
/// </summary>
/// <returns>The wrapped query for <typeparamref name="T"/>.</returns>
public override IQueryable<T> GetObjectQuery<T>()
{
    // The cache key is the entity type itself.
    Type type = typeof(T);

    object query;
    if (!_table.TryGetValue(type, out query))
    {
        query = new QueryTranslator<T>(_ctx.CreateObjectSet<T>());
        _table[type] = query;
    }

    return (IQueryable<T>)query;
}
This works as long as you don't use Multiple Entity Sets per Type (MEST), which is very rare.
If you are using 3.5, you can use the code in my Tip 13 to get the EntitySet name and feed it in like this:
/// <summary>
/// Returns the cached <c>QueryTranslator&lt;T&gt;</c> for <typeparamref name="T"/>,
/// creating it on first use from the entity set name reported by
/// <c>GetEntitySetName&lt;T&gt;()</c>.
/// </summary>
/// <returns>The wrapped query for <typeparamref name="T"/>.</returns>
public override IQueryable<T> GetObjectQuery<T>()
{
    // The cache key is the entity type itself.
    Type type = typeof(T);

    object query;
    if (!_table.TryGetValue(type, out query))
    {
        query = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + GetEntitySetName<T>() + "]"));
        _table[type] = query;
    }

    return (IQueryable<T>)query;
}
Hope this helps
Alex
Entity Framework Tips