Check whether a method is called inside a method using reflection - c#

I'm working with reflection and currently have a MethodBody. How do I check if a specific method is called inside the MethodBody?
Assembly assembly = Assembly.Load("Module1");
Type type = assembly.GetType("Module1.ModuleInit");
MethodInfo mi = type.GetMethod("Initialize");
MethodBody mb = mi.GetMethodBody();

Use Mono.Cecil. It is a single standalone assembly that will work on Microsoft .NET as well as Mono. (I think I used version 0.6 or thereabouts back when I wrote the code below)
Say you have a number of assemblies
IEnumerable<AssemblyDefinition> assemblies;
Load them using AssemblyFactory.
The following snippet enumerates all usages of methods in all types of these assemblies:
var methodUsages = assemblies
.SelectMany(assembly => assembly.MainModule.Types.Cast<TypeDefinition>())
.SelectMany(type => type.Methods.Cast<MethodDefinition>())
.Where(method => null != method.Body) // skip methods without a body (abstract, extern)
.SelectMany(method => method.Body.Instructions.Cast<Instruction>())
.Select(instr => instr.Operand)
.OfType<MethodReference>();
This will return all references to methods (so including use in reflection, or to construct expressions which may or may not be executed). As such, this is probably not very useful, except to show you what can be done with the Cecil API without too much of an effort :)
Note that this sample assumes a somewhat older version of Cecil (the one in mainstream Mono versions). Newer versions are more succinct (thanks to strongly typed generic collections) and faster.
Of course in your case you could have a single method reference as starting point. Say you want to detect when 'mytargetmethod' can actually be called directly inside 'startingpoint':
MethodReference startingpoint; // get it somewhere using Cecil
MethodReference mytargetmethod; // what you are looking for
bool isCalled = startingpoint
.GetOriginalMethod() // jump to original (for generics e.g.)
.Resolve() // get the definition from the IL image
.Body.Instructions.Cast<Instruction>()
.Any(i => (i.OpCode == OpCodes.Call || i.OpCode == OpCodes.Callvirt) && i.Operand == (mytargetmethod)); // static and non-virtual calls use 'call'
Call Tree Search
Here is a working snippet that recursively searches the (selected) methods that call each other, directly or indirectly.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;
namespace StackOverflow
{
/*
* breadth-first lazy search across a subset of the call tree rooting in startingPoint
*
* methodSelect selects the methods to recurse into
* resultGen generates the result objects to be returned by the enumerator
*
*/
class CallTreeSearch<T> : BaseCodeVisitor, IEnumerable<T> where T : class
{
private readonly Func<MethodReference, bool> _methodSelect;
private readonly Func<Instruction, Stack<MethodReference>, T> _transform;
private readonly IEnumerable<MethodDefinition> _startingPoints;
private readonly IDictionary<MethodDefinition, Stack<MethodReference>> _chain = new Dictionary<MethodDefinition, Stack<MethodReference>>();
private readonly ICollection<MethodDefinition> _seen = new HashSet<MethodDefinition>(new CompareMembers<MethodDefinition>());
private readonly ICollection<T> _results = new HashSet<T>();
private Stack<MethodReference> _currentStack;
private const int InfiniteRecursion = -1;
private readonly int _maxrecursiondepth;
private bool _busy;
public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultGen)
: this(startingPoints, methodSelect, resultGen, InfiniteRecursion)
{
}
public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultGen,
int maxrecursiondepth)
{
_startingPoints = startingPoints.ToList();
_methodSelect = methodSelect;
_maxrecursiondepth = maxrecursiondepth;
_transform = resultGen;
}
public override void VisitMethodBody(MethodBody body)
{
_seen.Add(body.Method); // avoid infinite recursion
base.VisitMethodBody(body);
}
public override void VisitInstructionCollection(InstructionCollection instructions)
{
foreach (Instruction instr in instructions)
VisitInstruction(instr);
base.VisitInstructionCollection(instructions);
}
public override void VisitInstruction(Instruction instr)
{
T result = _transform(instr, _currentStack);
if (result != null)
_results.Add(result);
var methodRef = instr.Operand as MethodReference; // TODO select calls only?
if (methodRef != null && _methodSelect(methodRef))
{
var resolve = methodRef.Resolve();
if (null != resolve && !(_chain.ContainsKey(resolve) || _seen.Contains(resolve)))
_chain.Add(resolve, new Stack<MethodReference>(_currentStack.Reverse()));
}
base.VisitInstruction(instr);
}
public IEnumerator<T> GetEnumerator()
{
lock (this) // not multithread safe
{
if (_busy)
throw new InvalidOperationException("CallTreeSearch enumerator is not reentrant");
_busy = true;
try
{
int recursionLevel = 0;
ResetToStartingPoints();
while (_chain.Count > 0 &&
((InfiniteRecursion == _maxrecursiondepth) || recursionLevel++ <= _maxrecursiondepth))
{
// swapout the collection because Visitor will modify
var clone = new Dictionary<MethodDefinition, Stack<MethodReference>>(_chain);
_chain.Clear();
foreach (var call in clone.Where(call => HasBody(call.Key)))
{
// Console.Error.Write("\rCallTreeSearch: level #{0}, scanning {1,-20}\r", recursionLevel, call.Key.Name + new string(' ',21));
_currentStack = call.Value;
_currentStack.Push(call.Key);
try
{
_results.Clear();
call.Key.Body.Accept(this); // grows _chain and _results
}
finally
{
_currentStack.Pop();
}
_currentStack = null;
foreach (var result in _results)
yield return result;
}
}
}
finally
{
_busy = false;
}
}
}
private void ResetToStartingPoints()
{
_chain.Clear();
_seen.Clear();
foreach (var startingPoint in _startingPoints)
{
_chain.Add(startingPoint, new Stack<MethodReference>());
_seen.Add(startingPoint);
}
}
private static bool HasBody(MethodDefinition methodDefinition)
{
return !(methodDefinition.IsAbstract || methodDefinition.Body == null);
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
internal class CompareMembers<T> : IComparer<T>, IEqualityComparer<T>
where T: class, IMemberReference
{
public int Compare(T x, T y)
{ return StringComparer.InvariantCultureIgnoreCase.Compare(KeyFor(x), KeyFor(y)); }
public bool Equals(T x, T y)
{ return KeyFor(x).Equals(KeyFor(y)); }
private static string KeyFor(T mr)
{ return null == mr ? "" : String.Format("{0}::{1}", mr.DeclaringType.FullName, mr.Name); }
public int GetHashCode(T obj)
{ return KeyFor(obj).GetHashCode(); }
}
}
Notes
Add some error handling around Resolve() (I have an extension method TryResolve() for the purpose).
Optionally, select only usages of MethodReferences in a call operation (call, calli, callvirt ...); see the // TODO.
Typical usage:
public static IEnumerable<T> SearchCallTree<T>(this TypeDefinition startingClass,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultFunc,
int maxdepth)
where T : class
{
return new CallTreeSearch<T>(startingClass.Methods.Cast<MethodDefinition>(), methodSelect, resultFunc, maxdepth);
}
public static IEnumerable<T> SearchCallTree<T>(this MethodDefinition startingMethod,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultFunc,
int maxdepth)
where T : class
{
return new CallTreeSearch<T>(new[] { startingMethod }, methodSelect, resultFunc, maxdepth);
}
// Actual usage:
private static IEnumerable<TypeUsage> SearchMessages(TypeDefinition uiType, bool onlyConstructions)
{
return uiType.SearchCallTree(IsBusinessCall,
(instruction, stack) => DetectRequestUsage(instruction, stack, onlyConstructions));
}
Note that completing a function like DetectRequestUsage to suit your needs is entirely up to you (edit: but see here). You can do whatever you want, and don't forget: you'll have the complete statically analyzed call stack at your disposal, so you can do pretty neat things with all that information!

Before it generates code, it must check if it already exists
There are a few cases where catching an exception is way cheaper than preventing it from being generated. This is a prime example. You can get the IL for the method body, but Reflection is not a disassembler. Nor is a disassembler a real fix: you'd have to disassemble the entire call tree to implement your desired behavior. After all, a method call in the body could itself call a method, et cetera. It is just much simpler to catch the exception that the jitter will throw when it compiles the IL.
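To illustrate what "Reflection is not a disassembler" means in practice, here is a minimal sketch that decodes the IL bytes by hand using only System.Reflection. It reports direct call/callvirt targets of a single method body and does not follow the call tree, so it does not remove the limitation described above. The names IlCallScanner, GetDirectCalls and ModuleInitSample are made up for this example, and the code assumes a little-endian platform.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;

public static class IlCallScanner
{
    // Opcode lookup keyed by encoded value; two-byte opcodes start with 0xFE.
    private static readonly Dictionary<ushort, OpCode> OpCodeMap =
        typeof(OpCodes).GetFields(BindingFlags.Public | BindingFlags.Static)
            .Where(f => f.FieldType == typeof(OpCode))
            .Select(f => (OpCode)f.GetValue(null))
            .ToDictionary(op => unchecked((ushort)op.Value));

    // Yields the targets of call/callvirt instructions found directly in the body.
    public static IEnumerable<MethodBase> GetDirectCalls(MethodBase method)
    {
        MethodBody body = method.GetMethodBody();
        if (body == null) yield break; // abstract, extern, ...

        byte[] il = body.GetILAsByteArray();
        Type[] typeArgs = method.DeclaringType != null && method.DeclaringType.IsGenericType
            ? method.DeclaringType.GetGenericArguments() : null;
        Type[] methodArgs = method.IsGenericMethod ? method.GetGenericArguments() : null;

        int pos = 0;
        while (pos < il.Length)
        {
            ushort code = il[pos++];
            if (code == 0xFE) code = (ushort)(0xFE00 | il[pos++]);
            OpCode op = OpCodeMap[code];

            if (op == OpCodes.Call || op == OpCodes.Callvirt)
            {
                // The operand is a 4-byte metadata token.
                int token = BitConverter.ToInt32(il, pos);
                yield return method.Module.ResolveMethod(token, typeArgs, methodArgs);
            }
            pos += OperandSize(op.OperandType, il, pos);
        }
    }

    // Number of operand bytes that follow an opcode.
    private static int OperandSize(OperandType type, byte[] il, int pos)
    {
        switch (type)
        {
            case OperandType.InlineNone: return 0;
            case OperandType.ShortInlineBrTarget:
            case OperandType.ShortInlineI:
            case OperandType.ShortInlineVar: return 1;
            case OperandType.InlineVar: return 2;
            case OperandType.InlineI8:
            case OperandType.InlineR: return 8;
            case OperandType.InlineSwitch: return 4 + 4 * BitConverter.ToInt32(il, pos);
            default: return 4; // tokens, 32-bit ints, branch targets, float32
        }
    }
}

// Hypothetical target, standing in for Module1.ModuleInit from the question.
public static class ModuleInitSample
{
    public static void Helper() { }
    public static void Initialize() { Helper(); }
}
```

Calling IlCallScanner.GetDirectCalls(typeof(ModuleInitSample).GetMethod("Initialize")) yields the MethodBase for Helper. Calls made through delegates, reflection or further down the call chain are not detected.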

One can use the StackTrace class:
System.Diagnostics.StackTrace st = new System.Diagnostics.StackTrace();
System.Diagnostics.StackFrame sf = st.GetFrame(1);
Console.Out.Write(sf.GetMethod().ReflectedType.Name + "." + sf.GetMethod().Name);
The 1 is the frame index and can be adjusted: 0 is the current method, 1 is its caller, and so on.
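As a small sketch of the frame index in action (CallerDemo and its methods are invented names), the helper below reports its immediate caller. Note that this inspects the running call stack, so it answers "who called me right now" rather than "does this method body contain a call". The NoInlining attributes are an assumption about what is needed to keep the JIT from merging frames.

```csharp
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;

public static class CallerDemo
{
    // Frame 0 is this method; frame 1 is whoever called it.
    [MethodImpl(MethodImplOptions.NoInlining)]
    public static string WhoCalledMe()
    {
        StackFrame frame = new StackTrace().GetFrame(1);
        var method = frame.GetMethod();
        return method.ReflectedType.Name + "." + method.Name;
    }

    [MethodImpl(MethodImplOptions.NoInlining)]
    public static string Caller()
    {
        string name = WhoCalledMe();
        Console.WriteLine(name); // keeps the call above out of tail position
        return name;
    }
}
```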

Related

Using reflection to find an enumerator's method

Background:
I am modifying existing code using the Harmony Library. The existing C# code follows this structure:
public class ToModify
{
public override void Update()
{
foreach (StatusItemGroup.Entry entry in collection)
{
// I am trying to alter an operation at the end of this loop.
}
}
}
public class StatusItemGroup
{
public IEnumerator<Entry> GetEnumerator()
{
return items.GetEnumerator();
}
private List<Entry> items = new List<Entry>();
public struct Entry { }
}
Because of this, I must modify the generated IL code. To do so, I must obtain the MethodInfo of my target operand. This is the target:
IL_12B6: callvirt instance bool [mscorlib]System.Collections.IEnumerator::MoveNext()
Question:
How do I obtain the MethodInfo for the MoveNext method of an enumerator?
What I've tried:
Everything I can think of has yielded null results. This is my most basic attempt:
MethodInfo targetMethod = typeof(IEnumerator<StatusItemGroup.Entry>).GetMethod("MoveNext");
I don't understand why this doesn't work, and I don't know what I need to do to correctly obtain the MethodInfo.
MoveNext is not defined on IEnumerator<T>, but on the non-generic IEnumerator which is inherited by IEnumerator<T>.
Interface inheritance is a little weird in combination with reflection, so you need to obtain the method info directly from the base interface where it's defined:
MethodInfo targetMethod = typeof(System.Collections.IEnumerator).GetMethod("MoveNext");
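If you do not know in advance which base interface declares the member, a small helper can search the interface and everything it inherits. This is only a sketch; InterfaceMethodLookup and FindInterfaceMethod are hypothetical names, and it assumes the method name is not overloaded (otherwise GetMethod throws AmbiguousMatchException).

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public static class InterfaceMethodLookup
{
    // Searches the interface itself, then every interface it inherits.
    public static MethodInfo FindInterfaceMethod(Type interfaceType, string name)
    {
        return new[] { interfaceType }
            .Concat(interfaceType.GetInterfaces())
            .Select(t => t.GetMethod(name))
            .FirstOrDefault(m => m != null);
    }
}
```

FindInterfaceMethod(typeof(IEnumerator<int>), "MoveNext") returns the MethodInfo declared on System.Collections.IEnumerator, whereas typeof(IEnumerator<int>).GetMethod("MoveNext") on its own returns null.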
I built this in the free LINQPad with Harmony 2.0 RC2. As you can see, I used a pass-through postfix to wrap the enumerator. There are other ways, and I suspect that you actually have an IEnumerable somewhere instead. That would be much easier to patch by applying the pass-through postfix directly to the original method that returns the IEnumerable. There would be no need to wrap the enumerator in that case.
But I don't know your full use case, so for now, this is the working example:
void Main()
{
var harmony = new Harmony("test");
harmony.PatchAll();
var group = new StatusItemGroup();
var items = new List<StatusItemGroup.Entry>() { StatusItemGroup.Entry.Make("A"), StatusItemGroup.Entry.Make("B") };
Traverse.Create(group).Field("items").SetValue(items);
var enumerator = group.GetEnumerator();
while(enumerator.MoveNext())
Console.WriteLine(enumerator.Current.id);
}
[HarmonyPatch]
class Patch
{
public class ProxyEnumerator<T> : IEnumerable<T>
{
public IEnumerator<T> enumerator;
public Func<T, T> transformer;
IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); }
public IEnumerator<T> GetEnumerator()
{
while(enumerator.MoveNext())
yield return transformer(enumerator.Current);
}
}
[HarmonyPatch(typeof(StatusItemGroup), "GetEnumerator")]
static IEnumerator<StatusItemGroup.Entry> Postfix(IEnumerator<StatusItemGroup.Entry> enumerator)
{
StatusItemGroup.Entry Transform(StatusItemGroup.Entry entry)
{
entry.id += "+";
return entry;
}
var myEnumerator = new ProxyEnumerator<StatusItemGroup.Entry>()
{
enumerator = enumerator,
transformer = Transform
};
return myEnumerator.GetEnumerator();
}
}
public class StatusItemGroup
{
public IEnumerator<Entry> GetEnumerator()
{
return items.GetEnumerator();
}
private List<Entry> items = new List<Entry>();
public struct Entry
{
public string id;
public static Entry Make(string id) { return new Entry() { id = id }; }
}
}

Get lowered function name from lambda

Is it possible with Roslyn to get to the name of compiler generated lambda methods?
For example imagine the following class:
public sealed class Foo
{
public void Bar()
{
Func<int, int> func = x =>
{
if (x > 0)
{
return x;
}
return -x;
};
}
}
The following code is generated:
public sealed class Foo
{
[CompilerGenerated]
[Serializable]
private sealed class <>c
{
public static readonly Foo.<>c <>9 = new Foo.<>c();
public static Func<int, int> <>9__0_0;
internal int <Bar>b__0_0(int x)
{
bool flag = x > 0;
int result;
if (flag)
{
result = x;
}
else
{
result = -x;
}
return result;
}
}
public void Bar()
{
Func<int, int> arg_20_0;
if ((arg_20_0 = Foo.<>c.<>9__0_0) == null)
{
Foo.<>c.<>9__0_0 = new Func<int, int>(Foo.<>c.<>9.<Bar>b__0_0);
}
}
}
Roslyn contains the code that is responsible for lowering the lambda into methods with different strategies depending on the circumstance.
But is there any simple way to get the name Foo.<>c.<Bar>b__0_0 if I have the symbol or the SimpleLambdaExpressionSyntax node?
This is obviously implementation-specific behavior, so this would require using the same Roslyn version the compiler is using, but that would be acceptable.
It is not possible with the Roslyn APIs. You can use something like System.Reflection.Metadata to read the IL and find the names if you need to. However it should be said that the names the compiler generates are an implementation detail and they will change.
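If the compiled assembly can be loaded, plain System.Reflection (swapped in here for System.Reflection.Metadata, purely to keep the sketch short) can also list the generated members. The helper name LoweredNames is made up, the filter on a leading '<' is a heuristic, and the exact names it returns remain a compiler implementation detail.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

public sealed class Foo
{
    public Func<int, int> Bar()
    {
        Func<int, int> func = x => x > 0 ? x : -x;
        return func;
    }
}

public static class LoweredNames
{
    private const BindingFlags All = BindingFlags.Public | BindingFlags.NonPublic |
                                     BindingFlags.Instance | BindingFlags.Static |
                                     BindingFlags.DeclaredOnly;

    // Lists methods with compiler-style names on the type itself and on its
    // compiler-generated nested classes (closure and cache classes).
    public static IEnumerable<string> GeneratedMethodNames(Type type)
    {
        var holders = new[] { type }.Concat(
            type.GetNestedTypes(BindingFlags.NonPublic)
                .Where(t => t.IsDefined(typeof(CompilerGeneratedAttribute), false)));

        return holders
            .SelectMany(t => t.GetMethods(All))
            .Where(m => m.Name.StartsWith("<", StringComparison.Ordinal))
            .Select(m => m.DeclaringType.FullName + "::" + m.Name);
    }
}
```

At runtime, a delegate instance exposes the same information directly: new Foo().Bar().Method.Name is the lowered method name.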

Ensure deferred execution will be executed only once or else

I ran into a weird issue and I'm wondering what I should do about it.
I have a class that returns an IEnumerable<MyClass> using deferred execution. Right now, there are two possible consumers. One of them sorts the result.
See the following example :
public class SomeClass
{
public IEnumerable<MyClass> GetMyStuff(Param givenParam)
{
double culmulativeSum = 0;
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile( o =>
{
bool returnValue = culmulativeSum < givenParam.Maximum;
culmulativeSum += o.SomeNumericValue;
return returnValue;
});
}
}
Consumers enumerate the deferred query only once, but if they were to enumerate it more than once, the result would be wrong because culmulativeSum wouldn't be reset. I found the issue by accident while unit testing.
The easiest way for me to fix the issue would be to just add .ToArray() and get rid of the deferred execution at the cost of a little bit of overhead.
I could also add unit test in consumers class to ensure they call it only once, but that wouldn't prevent any new consumer coded in the future from this potential issue.
Another thing that came to my mind was to make subsequent execution throw.
Something like
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...)
.ThrowIfExecutedMoreThan(1);
Obviously this doesn't exist.
Would it be a good idea to implement such thing and how would you do it?
Otherwise, if there is a big pink elephant that I don't see, pointing it out will be appreciated. (I feel there is one because this question is about a very basic scenario :| )
EDIT :
Here is a bad consumer usage example :
public class ConsumerClass
{
public void WhatEverMethod()
{
SomeClass some = new SomeClass();
var stuffs = some.GetMyStuff(param);
var nb = stuffs.Count(); //first deferred execution
var firstOne = stuffs.First(); //second deferred execution with the culmulativeSum not reset
}
}
You can solve the incorrect result issue by simply turning your method into an iterator:
double culmulativeSum = 0;
var query = myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...);
foreach (var item in query) yield return item;
It can be encapsulated in a simple helper method:
public static class Iterators
{
public static IEnumerable<T> Lazy<T>(Func<IEnumerable<T>> source)
{
foreach (var item in source())
yield return item;
}
}
Then all you need to do in such scenarios is to surround the original method body with Iterators.Lazy call, e.g.:
return Iterators.Lazy(() =>
{
double culmulativeSum = 0;
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...);
});
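To make the difference concrete, here is a self-contained sketch with made-up data (CumulativeDemo, Values and both method names are illustrative only). The first method builds the query once and shares the captured sum across enumerations; the second is an iterator, so its body, including the sum initialization, reruns for every enumeration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class CumulativeDemo
{
    private static readonly int[] Values = { 1, 2, 3, 4 };

    // Broken: the sum is captured once, when the query is built.
    public static IEnumerable<int> TakeUntilSumBroken(int maximum)
    {
        double cumulativeSum = 0;
        return Values.TakeWhile(v =>
        {
            bool keep = cumulativeSum < maximum;
            cumulativeSum += v;
            return keep;
        });
    }

    // Fixed: an iterator body runs from the top for each enumeration,
    // so every enumeration starts with a fresh sum.
    public static IEnumerable<int> TakeUntilSumFixed(int maximum)
    {
        double cumulativeSum = 0;
        var query = Values.TakeWhile(v =>
        {
            bool keep = cumulativeSum < maximum;
            cumulativeSum += v;
            return keep;
        });
        foreach (var item in query) yield return item;
    }
}
```

With a maximum of 3, counting the broken sequence twice gives 2 and then 0, while the fixed sequence gives 2 both times.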
You can use the following class:
public class JustOnceOrElseEnumerable<T> : IEnumerable<T>
{
private readonly IEnumerable<T> decorated;
public JustOnceOrElseEnumerable(IEnumerable<T> decorated)
{
this.decorated = decorated;
}
private bool CalledAlready;
public IEnumerator<T> GetEnumerator()
{
if (CalledAlready)
throw new Exception("Enumerated already");
CalledAlready = true;
return decorated.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
if (CalledAlready)
throw new Exception("Enumerated already");
CalledAlready = true;
return decorated.GetEnumerator();
}
}
to decorate an enumerable so that it can only be enumerated once. After that it would throw an exception.
You can use this class like this:
return new JustOnceOrElseEnumerable<MyClass>(
myStuff.Where(...)
...
);
Please note that I do not recommend this approach because it violates the contract of the IEnumerable interface and thus the Liskov Substitution Principle. It is legal for consumers of this contract to assume that they can enumerate the enumerable as many times as they like.
Instead, you can use a cached enumerable that caches the result of enumeration. This ensures that the enumerable is only enumerated once and that all subsequent enumeration attempts would read from the cache. See this answer here for more information.
Ivan's answer is very fitting for the underlying issue in OP's example - but for the general case, I have approached this in the past using an extension method similar to the one below. This ensures that the Enumerable has a single evaluation but is also deferred:
public static IMemoizedEnumerable<T> Memoize<T>(this IEnumerable<T> source)
{
return new MemoizedEnumerable<T>(source);
}
private class MemoizedEnumerable<T> : IMemoizedEnumerable<T>, IDisposable
{
private readonly IEnumerator<T> _sourceEnumerator;
private readonly List<T> _cache = new List<T>();
public MemoizedEnumerable(IEnumerable<T> source)
{
_sourceEnumerator = source.GetEnumerator();
}
public IEnumerator<T> GetEnumerator()
{
return IsMaterialized ? _cache.GetEnumerator() : Enumerate();
}
private IEnumerator<T> Enumerate()
{
foreach (var value in _cache)
{
yield return value;
}
while (_sourceEnumerator.MoveNext())
{
_cache.Add(_sourceEnumerator.Current);
yield return _sourceEnumerator.Current;
}
_sourceEnumerator.Dispose();
IsMaterialized = true;
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public List<T> Materialize()
{
if (IsMaterialized)
return _cache;
while (_sourceEnumerator.MoveNext())
{
_cache.Add(_sourceEnumerator.Current);
}
_sourceEnumerator.Dispose();
IsMaterialized = true;
return _cache;
}
public bool IsMaterialized { get; private set; }
void IDisposable.Dispose()
{
if(!IsMaterialized)
_sourceEnumerator.Dispose();
}
}
public interface IMemoizedEnumerable<T> : IEnumerable<T>
{
List<T> Materialize();
bool IsMaterialized { get; }
}
Example Usage:
void Consumer()
{
//var results = GetValuesComplex();
//var results = GetValuesComplex().ToList();
var results = GetValuesComplex().Memoize();
if(results.Any(i => i == 3))
{
Console.WriteLine("\nFirst Iteration");
//return; //Potential for early exit.
}
var last = results.Last(); // Causes multiple enumeration in naive case.
Console.WriteLine("\nSecond Iteration");
}
IEnumerable<int> GetValuesComplex()
{
for (int i = 0; i < 5; i++)
{
//... complex operations ...
Console.Write(i + ", ");
yield return i;
}
}
Naive: ✔ Deferred, ✘ Single enumeration.
ToList: ✘ Deferred, ✔ Single enumeration.
Memoize: ✔ Deferred, ✔ Single enumeration.
Edited to use the proper terminology and flesh out the implementation.

How to use the IEqualityComparer

I have some bills in my database with the same number. I want to get all of them without duplicates. I created a comparer class to do this, but it slows the function down considerably compared with the version without Distinct: from 0.6 sec to 3.2 sec!
Am I doing it right or do I have to use another method?
reg.AddRange(
(from a in this.dataContext.reglements
join b in this.dataContext.Clients on a.Id_client equals b.Id
where a.date_v <= datefin && a.date_v >= datedeb
where a.Id_client == b.Id
orderby a.date_v descending
select new Class_reglement
{
nom = b.Nom,
code = b.code,
Numf = a.Numf,
})
.AsEnumerable()
.Distinct(new Compare())
.ToList());
class Compare : IEqualityComparer<Class_reglement>
{
public bool Equals(Class_reglement x, Class_reglement y)
{
if (x.Numf == y.Numf)
{
return true;
}
else { return false; }
}
public int GetHashCode(Class_reglement codeh)
{
return 0;
}
}
Your GetHashCode implementation always returns the same value. Distinct relies on a good hash function to work efficiently because it internally builds a hash table.
When implementing interfaces of classes it is important to read the documentation, to know which contract you’re supposed to implement. [1]
In your code, the solution is to forward GetHashCode to Class_reglement.Numf.GetHashCode and implement it appropriately there.
Apart from that, your Equals method is full of unnecessary code. It could be rewritten as follows (same semantics, ¼ of the code, more readable):
public bool Equals(Class_reglement x, Class_reglement y)
{
return x.Numf == y.Numf;
}
Lastly, the ToList call is unnecessary and time-consuming: AddRange accepts any IEnumerable so conversion to a List isn’t required. AsEnumerable is also redundant here since processing the result in AddRange will cause this anyway.
[1] Writing code without knowing what it actually does is called cargo cult programming. It’s a surprisingly widespread practice. It fundamentally doesn’t work.
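Put together, a comparer along these lines might look as follows. This is a sketch under assumptions: Class_reglement is reduced to a stand-in with public fields, Numf is assumed to be a string, and NumfComparer is a name chosen for the example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Stand-in for the question's class; field types are assumed.
public class Class_reglement
{
    public string nom;
    public string code;
    public string Numf;
}

public class NumfComparer : IEqualityComparer<Class_reglement>
{
    public bool Equals(Class_reglement x, Class_reglement y)
    {
        if (ReferenceEquals(x, y)) return true;
        if (x is null || y is null) return false;
        return x.Numf == y.Numf;
    }

    // Hash the same key that Equals compares, so Distinct can bucket items
    // instead of comparing every pair.
    public int GetHashCode(Class_reglement obj)
    {
        return obj?.Numf?.GetHashCode() ?? 0;
    }
}
```

With this in place, items.Distinct(new NumfComparer()) only calls Equals for items that land in the same hash bucket.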
Try This code:
public class GenericCompare<T> : IEqualityComparer<T> where T : class
{
private Func<T, object> _expr { get; set; }
public GenericCompare(Func<T, object> expr)
{
this._expr = expr;
}
public bool Equals(T x, T y)
{
var first = _expr.Invoke(x);
var sec = _expr.Invoke(y);
if (first != null && first.Equals(sec))
return true;
else
return false;
}
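public int GetHashCode(T obj)
{
// Hash the same key that Equals compares, so equal items share a hash code.
var key = _expr.Invoke(obj);
return key == null ? 0 : key.GetHashCode();
}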
}
Example of its use would be
collection = collection
.Except(ExistedDataEles, new GenericCompare<DataEle>(x=>x.Id))
.ToList();
If you want a generic solution that creates an IEqualityComparer for your class based on a property (which acts as a key) of that class have a look at this:
public class KeyBasedEqualityComparer<T, TKey> : IEqualityComparer<T>
{
private readonly Func<T, TKey> _keyGetter;
public KeyBasedEqualityComparer(Func<T, TKey> keyGetter)
{
if (default(T) == null)
{
_keyGetter = (x) => x == null ? default : keyGetter(x);
}
else
{
_keyGetter = keyGetter;
}
}
public bool Equals(T x, T y)
{
return EqualityComparer<TKey>.Default.Equals(_keyGetter(x), _keyGetter(y));
}
public int GetHashCode(T obj)
{
TKey key = _keyGetter(obj);
return key == null ? 0 : key.GetHashCode();
}
}
public static class KeyBasedEqualityComparer<T>
{
public static KeyBasedEqualityComparer<T, TKey> Create<TKey>(Func<T, TKey> keyGetter)
{
return new KeyBasedEqualityComparer<T, TKey>(keyGetter);
}
}
Because the key is strongly typed, no boxing occurs when the key is a struct, which helps performance.
Usage is like this:
IEqualityComparer<Class_reglement> equalityComparer =
KeyBasedEqualityComparer<Class_reglement>.Create(x => x.Numf);
Just code, with implementation of GetHashCode and NULL validation:
public class Class_reglementComparer : IEqualityComparer<Class_reglement>
{
public bool Equals(Class_reglement x, Class_reglement y)
{
if (x is null || y is null)
return false;
return x.Numf == y.Numf;
}
public int GetHashCode(Class_reglement product)
{
//Check whether the object is null
if (product is null) return 0;
//Get hash code for the Numf field if it is not null.
int hashNumf = product.Numf == null ? 0 : product.Numf.GetHashCode();
return hashNumf;
}
}
Example: a list of Class_reglement, distinct by Numf:
List<Class_reglement> distinctItems = items.Distinct(new Class_reglementComparer()).ToList();
The purpose of this answer is to improve on previous answers by:
making the lambda expression optional in the constructor, so that full object equality is checked by default rather than equality on a single property;
operating on different types of classes, including complex types with sub-objects or nested lists, and not only on simple classes made up of primitive-type properties;
not taking into account possible differences between list containers.
Below you'll find a first, simple code sample that works only on simple types (those composed only of primitive properties), and a second, complete one (for a wider range of classes and complex types).
Here is my two-pennies attempt:
public class GenericEqualityComparer<T> : IEqualityComparer<T> where T : class
{
private Func<T, object> _expr { get; set; }
public GenericEqualityComparer() => _expr = null;
public GenericEqualityComparer(Func<T, object> expr) => _expr = expr;
public bool Equals(T x, T y)
{
var first = _expr?.Invoke(x) ?? x;
var sec = _expr?.Invoke(y) ?? y;
if (first == null && sec == null)
return true;
if (first != null && first.Equals(sec))
return true;
var typeProperties = typeof(T).GetProperties();
foreach (var prop in typeProperties)
{
var firstPropVal = prop.GetValue(first, null);
var secPropVal = prop.GetValue(sec, null);
if (firstPropVal != null && !firstPropVal.Equals(secPropVal))
return false;
}
return true;
}
public int GetHashCode(T obj) =>
_expr?.Invoke(obj).GetHashCode() ?? obj.GetHashCode();
}
I know we can still optimize it (and maybe make it recursive?).
But it works like a charm without much complexity and on a wide range of classes. ;)
Edit: After a day, here is my $10 attempt:
First, in a separate static extension class, you'll need:
public static class CollectionExtensions
{
public static bool HasSameLengthThan<T>(this IEnumerable<T> list, IEnumerable<T> expected)
{
if (list.IsNullOrEmptyCollection() && expected.IsNullOrEmptyCollection())
return true;
if ((list.IsNullOrEmptyCollection() && !expected.IsNullOrEmptyCollection()) || (!list.IsNullOrEmptyCollection() && expected.IsNullOrEmptyCollection()))
return false;
return list.Count() == expected.Count();
}
/// <summary>
/// Used to find out if a collection is null or if it contains no elements.
/// </summary>
/// <typeparam name="T">Type of the collection's items.</typeparam>
/// <param name="list">Collection of items to test.</param>
/// <returns><c>true</c> if the collection is <c>null</c> or empty (without items), <c>false</c> otherwise.</returns>
public static bool IsNullOrEmptyCollection<T>(this IEnumerable<T> list) => list == null || !list.Any();
}
Then, here is the updated class that works on a wider range of classes:
public class GenericComparer<T> : IEqualityComparer<T> where T : class
{
private Func<T, object> _expr { get; set; }
public GenericComparer() => _expr = null;
public GenericComparer(Func<T, object> expr) => _expr = expr;
public bool Equals(T x, T y)
{
var first = _expr?.Invoke(x) ?? x;
var sec = _expr?.Invoke(y) ?? y;
if (ObjEquals(first, sec))
return true;
var typeProperties = typeof(T).GetProperties();
foreach (var prop in typeProperties)
{
var firstPropVal = prop.GetValue(first, null);
var secPropVal = prop.GetValue(sec, null);
if (!ObjEquals(firstPropVal, secPropVal))
{
var propType = prop.PropertyType;
if (IsEnumerableType(propType) && firstPropVal is IEnumerable && !ArrayEquals(firstPropVal, secPropVal))
return false;
if (propType.IsClass)
{
if (!DeepEqualsFromObj(firstPropVal, secPropVal, propType))
return false;
if (!DeepObjEquals(firstPropVal, secPropVal))
return false;
}
}
}
return true;
}
public int GetHashCode(T obj) =>
_expr?.Invoke(obj).GetHashCode() ?? obj.GetHashCode();
#region Private Helpers
private bool DeepObjEquals(object x, object y) =>
new GenericComparer<object>().Equals(x, y);
private bool DeepEquals<U>(U x, U y) where U : class =>
new GenericComparer<U>().Equals(x, y);
private bool DeepEqualsFromObj(object x, object y, Type type)
{
dynamic a = Convert.ChangeType(x, type);
dynamic b = Convert.ChangeType(y, type);
return DeepEquals(a, b);
}
private bool IsEnumerableType(Type type) =>
type.GetInterface(nameof(IEnumerable)) != null;
private bool ObjEquals(object x, object y)
{
if (x == null && y == null) return true;
return x != null && x.Equals(y);
}
private bool ArrayEquals(object x, object y)
{
var firstList = new List<object>((IEnumerable<object>)x);
var secList = new List<object>((IEnumerable<object>)y);
if (!firstList.HasSameLengthThan(secList))
return false;
var elementType = firstList?.FirstOrDefault()?.GetType();
int cpt = 0;
foreach (var e in firstList)
{
if (!DeepEqualsFromObj(e, secList[cpt++], elementType))
return false;
}
return true;
}
#endregion Private Helpers
}
We can still optimize it, but it is worth giving it a try ^^.
Including your comparison class (or, more specifically, the AsEnumerable call you needed to get it to work) meant that the duplicate-removal logic moved from the database server to the database client (your application). Your client now needs to retrieve and then process a larger number of records, which will generally be less efficient than performing the lookup on the database, where the appropriate indexes can be used.
You should try to develop a where clause that satisfies your requirements instead, see Using an IEqualityComparer with a LINQ to Entities Except clause for more details.
IEquatable<T> can be a much easier way to do this with modern frameworks.
You get a nice simple bool Equals(T other) function and there's no messing around with casting or creating a separate class.
public class Person : IEquatable<Person>
{
public Person(string name, string hometown)
{
this.Name = name;
this.Hometown = hometown;
}
public string Name { get; set; }
public string Hometown { get; set; }
// can't get much simpler than this!
public bool Equals(Person other)
{
return other != null && this.Name == other.Name && this.Hometown == other.Hometown;
}
public override int GetHashCode()
{
return Name.GetHashCode(); // see other links for hashcode guidance
}
}
Note you DO have to implement GetHashCode if using this in a dictionary or with something like Distinct.
PS. I don't think any custom Equals methods work with entity framework directly on the database side (I think you know this because you do AsEnumerable) but this is a much simpler method to do a simple Equals for the general case.
If things don't seem to be working (such as duplicate key errors when doing ToDictionary) put a breakpoint inside Equals to make sure it's being hit and make sure you have GetHashCode defined (with override keyword).
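For reference, a slightly fuller variant of the same idea is sketched below. It is not the class above verbatim: it also overrides Equals(object) and hashes both properties through a value tuple, which are choices made for this example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public sealed class Person : IEquatable<Person>
{
    public Person(string name, string hometown)
    {
        Name = name;
        Hometown = hometown;
    }

    public string Name { get; }
    public string Hometown { get; }

    // Typed equality used by Distinct, Dictionary, HashSet, ...
    public bool Equals(Person other) =>
        other != null && Name == other.Name && Hometown == other.Hometown;

    // Keep the untyped overload consistent with the typed one.
    public override bool Equals(object obj) => Equals(obj as Person);

    // Combine exactly the members that Equals compares.
    public override int GetHashCode() => (Name, Hometown).GetHashCode();
}
```

Two Person instances with the same name and hometown now collapse to one under Distinct and collide as dictionary keys, without a separate comparer class.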

How do you deal with sequences of IDisposable using LINQ?

What's the best approach to call Dispose() on the elements of a sequence?
Suppose there's something like:
IEnumerable<string> locations = ...
var streams = locations.Select ( a => new FileStream ( a , FileMode.Open ) );
var notEmptyStreams = streams.Where ( a => a.Length > 0 );
//from this point on only `notEmptyStreams` will be used/visible
var firstBytes = notEmptyStreams.Select ( a => a.ReadByte () );
var average = firstBytes.Average ();
How do you dispose FileStream instances (as soon as they're no longer needed) while maintaining concise code?
To clarify: this is not an actual piece of code, those lines are methods across a set of classes, and FileStream type is also just an example.
Is doing something along the lines of:
public static IEnumerable<TSource> Where<TSource> (
this IEnumerable<TSource> source ,
Func<TSource , bool> predicate
)
where TSource : IDisposable {
foreach ( var item in source ) {
if ( predicate ( item ) ) {
yield return item;
}
else {
item.Dispose ();
}
}
}
might be a good idea?
Alternatively: do you always solve a very specific scenario involving IEnumerable<IDisposable> without trying to generalize? Is that because it is an atypical situation? Do you design around it in the first place? If so, how?
I would write a method, say, AsDisposableCollection that returns a wrapped IEnumerable which also implements IDisposable, so that you can use the usual using pattern. This is a bit more work (implementation of the method), but you need that only once and then you can use the method nicely (as often as you need):
using(var streams = locations.Select(a => new FileStream(a, FileMode.Open))
.AsDisposableCollection()) {
// ...
}
The implementation would look roughly like this (it is not complete - just to show the idea):
class DisposableCollection<T> : IDisposable, IEnumerable<T>
where T : IDisposable {
IEnumerable<T> en; // Wrapped enumerable
List<T> garbage; // To keep generated objects
public DisposableCollection(IEnumerable<T> en) {
this.en = en;
this.garbage = new List<T>();
}
// Enumerates over all the elements and stores generated
// elements in a list of garbage (to be disposed)
public IEnumerator<T> GetEnumerator() {
foreach(var o in en) {
garbage.Add(o);
yield return o;
}
}
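// Non-generic enumerator required by IEnumerable<T>
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() {
return GetEnumerator();
}
// Dispose all elements that were generated so far...
public void Dispose() {
foreach(var o in garbage) o.Dispose();
}
}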
I suggest you turn the streams variable into an Array or a List, because enumerating through it a second time will (if I'm not mistaken) create new copies of the streams.
var streams = locations.Select(a => new FileStream(a, FileMode.Open)).ToList();
// dispose right away of those you won't need
foreach (FileStream stream in streams.Where(a => a.Length == 0))
stream.Dispose();
var notEmptyStreams = streams.Where(a => a.Length > 0);
// the rest of your code here
foreach (FileStream stream in notEmptyStreams)
stream.Dispose();
EDIT: Given these constraints, maybe LINQ isn't the best tool. Maybe you could get away with a simple foreach loop?
var streams = locations.Select(a => new FileStream(a, FileMode.Open));
int count = 0;
int sum = 0;
foreach (FileStream stream in streams) using (stream)
{
if (stream.Length == 0) continue;
count++;
sum += stream.ReadByte();
}
double average = (double)sum / count; // floating-point division, like Average()
A simple solution is the following:
List<FileStream> streams = locations
.Select(a => new FileStream(a, FileMode.Open))
.ToList();
try
{
// Use the streams.
}
finally
{
foreach (IDisposable stream in streams)
stream.Dispose();
}
Note that even with this you could, in theory, still fail to close a stream if one of the FileStream constructors throws after others have already been constructed. To fix that, you need to be more careful when constructing the initial list:
List<Stream> streams = new List<Stream>();
try
{
foreach (string location in locations)
{
streams.Add(new FileStream(location, FileMode.Open));
}
// Use the streams.
}
finally { /* same as before */ }
It's a lot of code and it's not as concise as you wanted, but if you want to be sure that all your streams are closed, even when there are exceptions, then you should do this.
If you want something more LINQ-like you might want to read this article by Marc Gravell:
SelectMany; combining IDisposable and LINQ
Description
I've come up with a general solution to the problem :)
One of the things that were important to me was that everything is correctly disposed, even if I don't iterate the whole enumeration, which is the case when I use methods like FirstOrDefault (which I do quite often).
So I came up with a custom enumerator which handles all the disposing. All you have to do is call AsDisposeableEnumerable, which does all the magic for you.
GetMy.Disposeables()
.AsDisposeableEnumerable() // <-- all the magic is injected here
.Skip(5)
.Where(i => i > 1024)
.Select(i => new {myNumber = i})
.FirstOrDefault();
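Please note that this does not work for infinite enumerations, because Dispose() iterates over all remaining elements of the underlying enumerator in order to dispose them.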
The Code
My custom IEnumerable
public class DisposeableEnumerable<T> : IEnumerable<T> where T : System.IDisposable
{
private readonly IEnumerable<T> _enumerable;
public DisposeableEnumerable(IEnumerable<T> enumerable)
{
_enumerable = enumerable;
}
public IEnumerator<T> GetEnumerator()
{
return new DisposeableEnumerator<T>(_enumerable.GetEnumerator());
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
My custom IEnumerator
public class DisposeableEnumerator<T> : IEnumerator<T> where T : System.IDisposable
{
readonly List<T> toBeDisposed = new List<T>();
private readonly IEnumerator<T> _enumerator;
public DisposeableEnumerator(IEnumerator<T> enumerator)
{
_enumerator = enumerator;
}
public void Dispose()
{
// dispose the remaining disposeables
while (_enumerator.MoveNext()) {
T current = _enumerator.Current;
current.Dispose();
}
// dispose the provided disposeables
foreach (T disposeable in toBeDisposed) {
disposeable.Dispose();
}
// dispose the internal enumerator
_enumerator.Dispose();
}
public bool MoveNext()
{
bool result = _enumerator.MoveNext();
if (result) {
toBeDisposed.Add(_enumerator.Current);
}
return result;
}
public void Reset()
{
_enumerator.Reset();
}
public T Current
{
get
{
return _enumerator.Current;
}
}
object IEnumerator.Current
{
get { return Current; }
}
}
A fancy extension method to make things look good
public static class IDisposeableEnumerableExtensions
{
/// <summary>
/// Wraps the given IEnumerable in a DisposeableEnumerable which ensures that all the disposables are disposed correctly
/// </summary>
/// <typeparam name="T">The IDisposable type</typeparam>
/// <param name="enumerable">The enumerable to ensure disposing the elements of</param>
/// <returns></returns>
public static DisposeableEnumerable<T> AsDisposeableEnumerable<T>(this IEnumerable<T> enumerable) where T : System.IDisposable
{
return new DisposeableEnumerable<T>(enumerable);
}
}
Using code from https://lostechies.com/keithdahlby/2009/07/23/using-idisposables-with-linq/, you can turn your query into the following:
(
from location in locations
from stream in new FileStream(location, FileMode.Open).Use()
where stream.Length > 0
select stream.ReadByte()).Average()
You will need the following extension method:
public static IEnumerable<T> Use<T>(this T obj) where T : IDisposable
{
try
{
yield return obj;
}
finally
{
if (obj != null)
obj.Dispose();
}
}
This will properly dispose all the streams that you create, whether or not they are empty.
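To see that claim in action without touching the file system, here is a small sketch. TrackedResource is a hypothetical stand-in for FileStream that counts how often Dispose is called, and UseDemo is a made-up helper. The Use method has the same shape as the one above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for FileStream that records disposal.
public sealed class TrackedResource : IDisposable
{
    public static int DisposedCount;
    public int Length { get; }
    public TrackedResource(int length) { Length = length; }
    public void Dispose() { DisposedCount++; }
}

public static class DisposableExtensions
{
    // Yields the object once and disposes it when iteration moves on.
    public static IEnumerable<T> Use<T>(this T obj) where T : IDisposable
    {
        try
        {
            yield return obj;
        }
        finally
        {
            if (obj != null)
                obj.Dispose();
        }
    }
}

// Hypothetical demo helper, not part of the original answer.
public static class UseDemo
{
    public static double Run()
    {
        TrackedResource.DisposedCount = 0;
        var lengths = new[] { 0, 4, 0, 8 };
        // The second "from" becomes SelectMany, which disposes each
        // inner enumerator (running the finally block) after its item.
        return (from n in lengths
                from r in new TrackedResource(n).Use()
                where r.Length > 0
                select r.Length).Average();
    }
}
```

UseDemo.Run() should return 6 (the average of 4 and 8), and TrackedResource.DisposedCount should then be 4: the two "empty" resources that the where clause filtered out are disposed as well.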
Here is a simple wrapper that allows you to dispose any IEnumerable with using (to retain the collection type rather than casting as IEnumerable, we would need nested generic parameter types which C# does not seem to support):
public static class DisposableEnumerableExtensions {
public static DisposableEnumerable<T> AsDisposable<T>(this IEnumerable<T> enumerable) where T : IDisposable {
return new DisposableEnumerable<T>(enumerable);
}
}
public class DisposableEnumerable<T> : IDisposable where T : IDisposable {
public IEnumerable<T> Enumerable { get; }
public DisposableEnumerable(IEnumerable<T> enumerable) {
this.Enumerable = enumerable;
}
public void Dispose() {
foreach (var o in this.Enumerable) o.Dispose();
}
}
Usage:
using (var processes = System.Diagnostics.Process.GetProcesses().AsDisposable()) {
foreach (var p in processes.Enumerable) {
Console.Write(p.Id);
}
}
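One caveat with this wrapper: Dispose enumerates the sequence again. That is fine for an array such as the one GetProcesses returns, but a lazy query would construct fresh objects on the second pass and dispose those instead of the ones you used. A possible variation, sketched under that assumption, copies the sequence into a list once. MaterializedDisposables, CountingDisposable and MaterializeDemo are names made up for this example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical resource that counts creations and disposals.
public sealed class CountingDisposable : IDisposable
{
    public static int Created;
    public static int Disposed;
    public CountingDisposable() { Created++; }
    public void Dispose() { Disposed++; }
}

// Enumerates the source exactly once, so Dispose() sees the same instances.
public sealed class MaterializedDisposables<T> : IDisposable where T : IDisposable
{
    public IReadOnlyList<T> Items { get; }

    public MaterializedDisposables(IEnumerable<T> source)
    {
        var items = new List<T>();
        try
        {
            foreach (var item in source) items.Add(item);
        }
        catch
        {
            // Creation failed partway: release what was already created.
            foreach (var item in items) item.Dispose();
            throw;
        }
        Items = items;
    }

    public void Dispose()
    {
        foreach (var item in Items) item.Dispose();
    }
}

// Hypothetical demo helper, not part of the original answer.
public static class MaterializeDemo
{
    public static bool Run()
    {
        CountingDisposable.Created = 0;
        CountingDisposable.Disposed = 0;
        // Lazy query: each enumeration would construct new objects.
        IEnumerable<CountingDisposable> lazy =
            Enumerable.Range(0, 3).Select(_ => new CountingDisposable());
        using (var wrapper = new MaterializedDisposables<CountingDisposable>(lazy))
        {
            foreach (var item in wrapper.Items) { /* use item here */ }
        }
        // Exactly the three created objects were disposed, no extras made.
        return CountingDisposable.Created == 3 && CountingDisposable.Disposed == 3;
    }
}
```

The trade-off is that the whole sequence is held in memory until the using block ends, so this only suits sequences of modest size.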
