There are numerous reports that static checking of Contract.ForAll has only limited or no support.
I did a lot of experimenting and found that it can work with:
Contract.ForAll(items, i => i != null)
Contract.ForAll(items, p) where p is of type Predicate<T>
it cannot work with:
Field access
Property access
Method group (I think delegate is allocated here anyway)
Instance method call
My questions are:
What are other types of code that ForAll can work with?
Does Code Contracts understand that once Contract.ForAll(items, i => i != null) is proven, an item taken from the list later in the code (e.g. by indexing) is not null?
Here is full test code:
public sealed class Test
{
public bool Field;
public static Predicate<Test> Predicate;
[Pure]
public bool Property
{
get { return Field; }
}
[Pure]
public static bool Method(Test t)
{
return t.Field;
}
[Pure]
public bool InstanceMethod()
{
return Field;
}
public static void Test1()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, i => i != null));
Contract.Assert(Contract.ForAll(items, i => i != null)); // OK
}
public static void Test2()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, Predicate));
Contract.Assert(Contract.ForAll(items, Predicate)); // OK
}
public static void Test3()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, i => i.Field));
Contract.Assert(Contract.ForAll(items, i => i.Field)); // assert unproven
}
public static void Test4()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, i => i.Property));
Contract.Assert(Contract.ForAll(items, i => i.Property)); // assert unproven
}
public static void Test5()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, Method));
Contract.Assert(Contract.ForAll(items, Method)); // assert unproven
}
public static void Test6()
{
var items = new List<Test>();
Contract.Assume(Contract.ForAll(items, i => i.InstanceMethod()));
Contract.Assert(Contract.ForAll(items, i => i.InstanceMethod()));// assert unproven
}
}
I was not able to find more working expressions; in fact, I found that even Contract.ForAll(items, i => i != null) does not work reliably (although it does understand that the item is not null when it is later used inside a foreach in the same function). In the end, I gave up on using more complex ForAll contracts with the static checker.
Instead, I devised a way to control which contracts are for the static checker and which are for the runtime checker. I present this method here in the hope that it might be useful to people interested in the original question. The benefit is the ability to write more complex contracts that are checked at runtime only, leaving only easily provable contracts for the static checker (and keeping the warning count low).
For that, two Debug builds are needed (if you don't already have them): Debug and Debug + Static Contracts. The Debug build has the conditional compilation symbol MYPROJECT_CONTRACTS_RUNTIME defined, so it receives all Contract. and RtContract. contracts. Other builds receive only Contract. contracts.
public static class RtContract
{
[Pure] [ContractAbbreviator] [Conditional("MYPROJECT_CONTRACTS_RUNTIME")]
public static void Requires(bool condition)
{
Contract.Requires(condition);
}
[Pure] [ContractAbbreviator] [Conditional("MYPROJECT_CONTRACTS_RUNTIME")]
public static void Ensures(bool condition)
{
Contract.Ensures(condition);
}
[Pure] [Conditional("MYPROJECT_CONTRACTS_RUNTIME")]
public static void Assume(bool condition)
{
Contract.Assume(condition);
}
}
public class Usage
{
void Test (int x)
{
Contract.Requires(x >= 0); // for static and runtime
RtContract.Requires(x.IsFibonacci()); // for runtime only
}
}
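The mechanism RtContract relies on is the [Conditional] attribute: the compiler drops calls to a conditional method unless the named symbol is defined where the call is compiled. The sketch below shows just that mechanism, without Code Contracts; DEMO_RUNTIME_CHECKS, RuntimeOnly and ConditionalDemo are hypothetical names invented for illustration.

```csharp
#define DEMO_RUNTIME_CHECKS // remove this line and the Check call below is not compiled

using System;
using System.Diagnostics;

public static class RuntimeOnly
{
    public static int Calls;

    // Calls to this method are only emitted when DEMO_RUNTIME_CHECKS is defined at the call site.
    [Conditional("DEMO_RUNTIME_CHECKS")]
    public static void Check(bool condition)
    {
        Calls++;
        if (!condition) throw new InvalidOperationException("runtime-only check failed");
    }
}

public static class ConditionalDemo
{
    // Returns how many checks actually ran: 1 with the symbol defined, 0 without it.
    public static int Run()
    {
        RuntimeOnly.Calls = 0;
        RuntimeOnly.Check(1 + 1 == 2);
        return RuntimeOnly.Calls;
    }

    public static void Main()
    {
        Console.WriteLine(ConditionalDemo.Run());
    }
}
```

In a real project the symbol would normally be set per build configuration rather than with #define, as described above.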
By decompiling System.Diagnostics.Contracts in mscorlib.dll we can easily see how Contract.ForAll is built: it takes a collection and a predicate.
public static bool ForAll<T>(IEnumerable<T> collection, Predicate<T> predicate)
{
if (collection == null)
{
throw new ArgumentNullException("collection");
}
if (predicate == null)
{
throw new ArgumentNullException("predicate");
}
foreach (T current in collection)
{
if (!predicate(current))
{
return false;
}
}
return true;
}
So when you write Contract.ForAll(items, i => i.Field), i => i.Field is the predicate.
Following your example, every test method passes an empty list to Contract.ForAll, which returns true at runtime because the foreach block is never entered.
Taking it further, if you add an item to your list
var items = new List<Test>() {new Test()}; and run the test again, it will return false, because your public bool Field; is false by default.
According to the documentation, Contract.ForAll
Determines whether all the elements in a collection exist within a
function
So my conclusion is that it is not that Contract.ForAll can't work with something; rather, at least one element in your collection makes the predicate return false or is null.
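The runtime behaviour described in this answer can be checked with a small sketch. It only exercises Contract.ForAll as an ordinary method call and says nothing about what the static checker can prove; Sample and ForAllDemo are hypothetical names standing in for the Test class from the question.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;

public sealed class Sample
{
    public bool Field; // false by default
}

public static class ForAllDemo
{
    public static void Main()
    {
        var items = new List<Sample>();

        // Empty collection: the predicate is never invoked, so the result is true.
        Console.WriteLine(Contract.ForAll(items, i => i.Field)); // True

        // One element whose Field is still false: the predicate fails.
        items.Add(new Sample());
        Console.WriteLine(Contract.ForAll(items, i => i.Field)); // False
    }
}
```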
I ran into a weird issue and I'm wondering what I should do about it.
I have a class that returns an IEnumerable<MyClass> with deferred execution. Right now, there are two possible consumers. One of them sorts the result.
See the following example :
public class SomeClass
{
public IEnumerable<MyClass> GetMyStuff(Param givenParam)
{
double culmulativeSum = 0;
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile( o =>
{
bool returnValue = culmulativeSum < givenParam.Maximum;
culmulativeSum += o.SomeNumericValue;
return returnValue;
});
}
}
Consumers currently enumerate the deferred query only once, but if they enumerated it more than once, the result would be wrong because culmulativeSum wouldn't be reset. I found the issue by accident while unit testing.
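A minimal, self-contained reproduction of the problem looks like this. TakeUntilSum is a hypothetical stand-in for GetMyStuff, with the Where and OrderBy steps left out and plain integers instead of MyClass.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DeferredDemo
{
    public static IEnumerable<int> TakeUntilSum(IEnumerable<int> source, int maximum)
    {
        int sum = 0; // captured once and shared by every enumeration of the returned query
        return source.TakeWhile(n =>
        {
            bool keep = sum < maximum;
            sum += n;
            return keep;
        });
    }

    public static void Main()
    {
        var query = TakeUntilSum(new[] { 1, 2, 3, 4 }, 5);
        Console.WriteLine(query.Count()); // 3: first enumeration starts with sum == 0
        Console.WriteLine(query.Count()); // 0: second enumeration starts with the leftover sum
    }
}
```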
The easiest way for me to fix the issue would be to just add .ToArray() and get rid of the deferred execution at the cost of a little bit of overhead.
I could also add unit test in consumers class to ensure they call it only once, but that wouldn't prevent any new consumer coded in the future from this potential issue.
Another thing that came to my mind was to make subsequent execution throw.
Something like
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...)
.ThrowIfExecutedMoreThan(1);
Obviously this doesn't exist.
Would it be a good idea to implement such thing and how would you do it?
Otherwise, if there is a big pink elephant that I don't see, pointing it out will be appreciated. (I feel there is one because this question is about a very basic scenario :| )
EDIT :
Here is a bad consumer usage example :
public class ConsumerClass
{
public void WhatEverMethod()
{
SomeClass some = new SomeClass();
var stuffs = some.GetMyStuff(param);
var nb = stuffs.Count(); //first deferred execution
var firstOne = stuffs.First(); //second deferred execution with the culmulativeSum not reset
}
}
You can solve the incorrect result issue by simply turning your method into an iterator:
double culmulativeSum = 0;
var query = myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...);
foreach (var item in query) yield return item;
It can be encapsulated in a simple extension method:
public static class Iterators
{
public static IEnumerable<T> Lazy<T>(Func<IEnumerable<T>> source)
{
foreach (var item in source())
yield return item;
}
}
Then all you need to do in such scenarios is to surround the original method body with Iterators.Lazy call, e.g.:
return Iterators.Lazy(() =>
{
double culmulativeSum = 0;
return myStuff.Where(...)
.OrderBy(...)
.TakeWhile(...);
});
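To see the effect, here is a small sketch that applies the same wrapper to a simplified query over integers. LazyDemo and TakeUntilSum are hypothetical names, and the query body is a stand-in for the original Where/OrderBy/TakeWhile chain. Because the lambda passed to Lazy runs again on every enumeration, the running sum starts from zero each time.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class LazyDemo
{
    // Same shape as Iterators.Lazy: the factory is invoked once per enumeration.
    public static IEnumerable<T> Lazy<T>(Func<IEnumerable<T>> source)
    {
        foreach (var item in source())
            yield return item;
    }

    public static IEnumerable<int> TakeUntilSum(IEnumerable<int> values, int maximum)
    {
        return Lazy(() =>
        {
            int sum = 0; // re-created for every enumeration
            return values.TakeWhile(n =>
            {
                bool keep = sum < maximum;
                sum += n;
                return keep;
            });
        });
    }

    public static void Main()
    {
        var query = TakeUntilSum(new[] { 1, 2, 3, 4 }, 5);
        Console.WriteLine(query.Count()); // 3
        Console.WriteLine(query.Count()); // 3 again, the state was reset
    }
}
```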
You can use the following class:
public class JustOnceOrElseEnumerable<T> : IEnumerable<T>
{
private readonly IEnumerable<T> decorated;
public JustOnceOrElseEnumerable(IEnumerable<T> decorated)
{
this.decorated = decorated;
}
private bool CalledAlready;
public IEnumerator<T> GetEnumerator()
{
if (CalledAlready)
throw new Exception("Enumerated already");
CalledAlready = true;
return decorated.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
if (CalledAlready)
throw new Exception("Enumerated already");
CalledAlready = true;
return decorated.GetEnumerator();
}
}
to decorate an enumerable so that it can only be enumerated once. After that it would throw an exception.
You can use this class like this:
return new JustOnceOrElseEnumerable<MyClass>(
myStuff.Where(...)
...
);
Please note that I do not recommend this approach because it violates the contract of the IEnumerable interface and thus the Liskov Substitution Principle. It is legal for consumers of this contract to assume that they can enumerate the enumerable as many times as they like.
Instead, you can use a cached enumerable that caches the result of enumeration. This ensures that the enumerable is only enumerated once and that all subsequent enumeration attempts would read from the cache. See this answer here for more information.
Ivan's answer is very fitting for the underlying issue in OP's example - but for the general case, I have approached this in the past using an extension method similar to the one below. This ensures that the Enumerable has a single evaluation but is also deferred:
public static IMemoizedEnumerable<T> Memoize<T>(this IEnumerable<T> source)
{
return new MemoizedEnumerable<T>(source);
}
private class MemoizedEnumerable<T> : IMemoizedEnumerable<T>, IDisposable
{
private readonly IEnumerator<T> _sourceEnumerator;
private readonly List<T> _cache = new List<T>();
public MemoizedEnumerable(IEnumerable<T> source)
{
_sourceEnumerator = source.GetEnumerator();
}
public IEnumerator<T> GetEnumerator()
{
return IsMaterialized ? _cache.GetEnumerator() : Enumerate();
}
private IEnumerator<T> Enumerate()
{
foreach (var value in _cache)
{
yield return value;
}
while (_sourceEnumerator.MoveNext())
{
_cache.Add(_sourceEnumerator.Current);
yield return _sourceEnumerator.Current;
}
_sourceEnumerator.Dispose();
IsMaterialized = true;
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
public List<T> Materialize()
{
if (IsMaterialized)
return _cache;
while (_sourceEnumerator.MoveNext())
{
_cache.Add(_sourceEnumerator.Current);
}
_sourceEnumerator.Dispose();
IsMaterialized = true;
return _cache;
}
public bool IsMaterialized { get; private set; }
void IDisposable.Dispose()
{
if(!IsMaterialized)
_sourceEnumerator.Dispose();
}
}
public interface IMemoizedEnumerable<T> : IEnumerable<T>
{
List<T> Materialize();
bool IsMaterialized { get; }
}
Example Usage:
void Consumer()
{
//var results = GetValuesComplex();
//var results = GetValuesComplex().ToList();
var results = GetValuesComplex().Memoize();
if(results.Any(i => i == 3))
{
Console.WriteLine("\nFirst Iteration");
//return; //Potential for early exit.
}
var last = results.Last(); // Causes multiple enumeration in naive case.
Console.WriteLine("\nSecond Iteration");
}
IEnumerable<int> GetValuesComplex()
{
for (int i = 0; i < 5; i++)
{
//... complex operations ...
Console.Write(i + ", ");
yield return i;
}
}
Naive: ✔ Deferred, ✘ Single enumeration.
ToList: ✘ Deferred, ✔ Single enumeration.
Memoize: ✔ Deferred, ✔ Single enumeration.
Edited to use the proper terminology and flesh out the implementation.
I am currently trying to learn how to use unit testing, and I have created the actual list of 3 animal objects and the expected list of 3 animal objects. The question is how do I Assert to check the lists are equal? I have tried CollectionAssert.AreEqual and Assert.AreEqual but to no avail. Any help would be appreciated.
The test method:
[TestMethod]
public void createAnimalsTest2()
{
animalHandler animalHandler = new animalHandler();
// arrange
List<Animal> expected = new List<Animal>();
Animal dog = new Dog("",0);
Animal cat = new Cat("",0);
Animal mouse = new Mouse("",0);
expected.Add(dog);
expected.Add(cat);
expected.Add(mouse);
//actual
List<Animal> actual = animalHandler.createAnimals("","","",0,0,0);
//assert
//this is the line that does not evaluate as true
Assert.Equals(expected ,actual);
}
That is expected: although the lists may look the same, they are two different objects containing the same data.
In order to compare lists, you should use CollectionAssert:
CollectionAssert.AreEqual(expected, actual);
That should do the trick.
Just in case someone comes across this in the future: the answer was that I had to create an IEqualityComparer implementation, as shown below:
public class MyPersonEqualityComparer : IEqualityComparer<MyPerson>
{
public bool Equals(MyPerson x, MyPerson y)
{
if (object.ReferenceEquals(x, y)) return true;
if (object.ReferenceEquals(x, null)||object.ReferenceEquals(y, null)) return false;
return x.Name == y.Name && x.Age == y.Age;
}
public int GetHashCode(MyPerson obj)
{
if (object.ReferenceEquals(obj, null)) return 0;
int hashCodeName = obj.Name == null ? 0 : obj.Name.GetHashCode();
int hashCodeAge = obj.Age.GetHashCode();
return hashCodeName ^ hashCodeAge;
}
}
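The answer does not say how the comparer was plugged into the assertion. One possible way, shown here as an assumption, is Enumerable.SequenceEqual, which accepts an IEqualityComparer<T>; in a test the result would be wrapped in the framework's "is true" assertion. The MyPerson shape (Name and Age) is inferred from the comparer and is otherwise hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class MyPerson
{
    public string Name { get; set; }
    public int Age { get; set; }
}

public class MyPersonEqualityComparer : IEqualityComparer<MyPerson>
{
    public bool Equals(MyPerson x, MyPerson y)
    {
        if (object.ReferenceEquals(x, y)) return true;
        if (object.ReferenceEquals(x, null) || object.ReferenceEquals(y, null)) return false;
        return x.Name == y.Name && x.Age == y.Age;
    }

    public int GetHashCode(MyPerson obj)
    {
        if (object.ReferenceEquals(obj, null)) return 0;
        return (obj.Name == null ? 0 : obj.Name.GetHashCode()) ^ obj.Age.GetHashCode();
    }
}

public static class ComparerDemo
{
    public static void Main()
    {
        var expected = new List<MyPerson> { new MyPerson { Name = "Ann", Age = 30 } };
        var actual = new List<MyPerson> { new MyPerson { Name = "Ann", Age = 30 } };

        // Default equality compares references, so distinct instances differ.
        Console.WriteLine(expected.SequenceEqual(actual)); // False

        // With the comparer, elements are compared by Name and Age.
        Console.WriteLine(expected.SequenceEqual(actual, new MyPersonEqualityComparer())); // True
    }
}
```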
I am of the opinion that implementing IEqualityComparer (Equals() and GetHashCode()) for testing purposes only is a code smell. I would rather use the following assertion method, which lets you freely choose the properties to assert on:
public static void AssertListEquals<TE, TA>(Action<TE, TA> asserter, IEnumerable<TE> expected, IEnumerable<TA> actual)
{
IList<TA> actualList = actual.ToList();
IList<TE> expectedList = expected.ToList();
Assert.True(
actualList.Count == expectedList.Count,
$"Lists have different sizes. Expected list: {expectedList.Count}, actual list: {actualList.Count}");
for (var i = 0; i < expectedList.Count; i++)
{
try
{
asserter.Invoke(expectedList[i], actualList[i]);
}
catch (Exception e)
{
Assert.True(false, $"Assertion failed because: {e.Message}");
}
}
}
In action it would look as follows:
public void TestMethod()
{
//Arrange
//...
//Act
//...
//Assert
AssertAnimals(expectedAnimals, actualAnimals);
}
private void AssertAnimals(IEnumerable<Animal> expectedAnimals, IEnumerable<Animal> actualAnimals)
{
ListAsserter.AssertListEquals(
(e,a) => AssertAnimal(e,a),
expectedAnimals,
actualAnimals);
}
private void AssertAnimal(Animal expected, Animal actual)
{
Assert.Equal(expected.Name, actual.Name);
Assert.Equal(expected.Weight, actual.Weight);
//Additional properties to assert...
}
I am using xUnit for the simple Assert.True(...) and Assert.Equal(...) calls, but you can use any other unit test library for that. Hope it helps someone! ;)
I modified the AssertListEquals() method to use the standard Assert.All():
public static void AssertListEquals<TE, TA>(IEnumerable<TE> expected, IEnumerable<TA> actual, Action<TE, TA> asserter)
{
if (expected == null && actual == null) return;
Assert.NotNull(expected);
Assert.NotNull(actual);
Assert.True(
actual.Count() == expected.Count(),
$"Lists have different sizes. Expected list: {expected.Count()}, actual list: {actual.Count()}");
var i = 0;
Assert.All(expected, e =>
{
try
{
asserter(e, actual.Skip(i).First());
}
finally
{
i++;
}
});
}
I have a Item[] _items array of items, where some of the items may be null. I wish to check if the array contains at least one non-null item.
My current implementation seems a little complicated:
internal bool IsEmtpy { get { return (!(this.IsNotEmpty)); } }
private bool IsNotEmpty { get { return ( this.Items.Any(t => t != null));} }
So my question is: Is there a simpler way to check if a typed array of reference objects contains at least one non null object?
There is no complexity in your implementation. Basically, the only way to check whether there are non-null values in the array is to look through the values until you reach a non-null value or the end of the array.
The following code is easier to understand though:
internal bool IsEmtpy { get { return this.Items.All(t => t == null); } }
private bool IsNotEmpty { get { return this.Items.Any(t => t != null); } }
And it is probably better to extend IEnumerable as follows:
public static class Extensions {
public static bool ContainsOnlyEmpty<TSource>(this IEnumerable<TSource> source) {
return source.All(t => t == null);
}
public static bool ContainsNonEmpty<TSource>(this IEnumerable<TSource> source) {
return source.Any(t => t != null);
}
}
and use it like this: bool nonEmpty = this.Items.ContainsNonEmpty();
I'm trying to get a deeper understanding of Monads. Therefore I started digging a little into the Maybe Monad.
There is one thing that I just don't seem to get right. Read this:
"So the Maybe Bind acts a short circuit. In any chain of operations, if any one of them returns Nothing, the evaluation will cease and Nothing will be returned from the entire chain."
From: http://mikehadlow.blogspot.com/2011/01/monads-in-c-5-maybe.html
And this:
"For the Maybe<T> type, binding is implemented according to a simple rule: if chain returns an empty value at some point, further steps in the chain are ignored and an empty value is returned instead"
From: "Functional Programming in C#" http://www.amazon.com/Functional-Programming-Techniques-Projects-Programmer/dp/0470744588/
Ok, let's look at the code. Here is my Maybe Monad:
public class Maybe<T>
{
public static readonly Maybe<T> Empty = new Maybe<T>();
public Maybe(T value)
{
Value = value;
}
private Maybe()
{
}
public bool HasValue()
{
return !EqualityComparer<T>.Default.Equals(Value, default(T));
}
public T Value { get; private set; }
public Maybe<R> Bind<R>(Func<T, Maybe<R>> apply)
{
return HasValue() ? apply(Value) : Maybe<R>.Empty;
}
}
public static class MaybeExtensions
{
public static Maybe<T> ToMaybe<T>(this T value)
{
return new Maybe<T>(value);
}
}
And here is my example code using the monad:
class Program
{
static void Main(string[] args)
{
var node = new Node("1", new Node("2", new Node("3", new Node("4", null))));
var childNode = node.ChildNode
.ToMaybe()
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe());
Console.WriteLine(childNode.HasValue() ? childNode.Value.Value : "");
Console.ReadLine();
}
}
public class Node
{
public Node(string value, Node childNode)
{
Value = value;
ChildNode = childNode;
}
public string Value { get; set; }
public Node ChildNode { get; private set; }
}
It's clear that we are trying to dig deeper into the node tree than is possible. However, I fail to see how it behaves according to the quotes I mentioned. Of course, I have factored out the null checks and the example works, but it doesn't break the chain early. If you set breakpoints you will see that every Bind() call is executed, the last ones without a value. Does that mean that if I dig 20 levels deep and the tree only goes down 3 levels, I will still check 20 levels, or am I wrong?
Compare this to the non-monad approach:
if (node.ChildNode != null
&& node.ChildNode.ChildNode != null
&& node.ChildNode.ChildNode.ChildNode != null)
{
Console.WriteLine(node.ChildNode.ChildNode.ChildNode.Value);
}
Isn't this actually what should be called a short circuit? Because in this case the if really breaks at the level where the first value is null.
Can anybody help me to get this clear?
UPDATE
As Patrik pointed out, yes it is true each bind will be invoked even if we only have 3 levels and try to go 20 levels deep. However, the actual expression provided to the Bind() call won't be evaluated. We can edit the example to make the effect clear:
var childNode = node.ChildNode
.ToMaybe()
.Bind(x =>
{
Console.WriteLine("We will see this");
return x.ChildNode.ToMaybe();
})
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x => x.ChildNode.ToMaybe())
.Bind(x =>
{
Console.WriteLine("We won't see this");
return x.ChildNode.ToMaybe();
});
I have an implementation of the Maybe monad in C# that differs a little from yours. First of all, it's not tied to null checks; I believe my implementation more closely resembles what happens in a standard Maybe implementation in, for example, Haskell.
My implementation:
public abstract class Maybe<T>
{
public static readonly Maybe<T> Nothing = new NothingMaybe();
public static Maybe<T> Just(T value)
{
return new JustMaybe(value);
}
public abstract Maybe<T2> Bind<T2>(Func<T, Maybe<T2>> binder);
private class JustMaybe
: Maybe<T>
{
readonly T value;
public JustMaybe(T value)
{
this.value = value;
}
public override Maybe<T2> Bind<T2>(Func<T, Maybe<T2>> binder)
{
return binder(this.value);
}
}
private class NothingMaybe
: Maybe<T>
{
public override Maybe<T2> Bind<T2>(Func<T, Maybe<T2>> binder)
{
return Maybe<T2>.Nothing;
}
}
}
As you can see, the Bind function of NothingMaybe just returns Nothing, so the binder expression passed in is never evaluated. It's short-circuiting in the sense that no more binder expressions will be evaluated once you get into the "nothing state"; however, the Bind function itself will still be invoked for each monad in the chain.
This implementation of maybe could be used for any type of "uncertain operation", for example a null check or checking for an empty string, this way all those different types of operations can be chained together:
public static class Maybe
{
public static Maybe<T> NotNull<T>(T value) where T : class
{
return value != null ? Maybe<T>.Just(value) : Maybe<T>.Nothing;
}
public static Maybe<string> NotEmpty(string value)
{
return value.Length != 0 ? Maybe<string>.Just(value) : Maybe<string>.Nothing;
}
}
string foo = "whatever";
Maybe.NotNull(foo).Bind(x => Maybe.NotEmpty(x)).Bind(x => { Console.WriteLine(x); return Maybe<string>.Just(x); });
This would print "whatever" to the console, however if the value was null or empty it would do nothing.
As I understand it, all Bind methods will be invoked, but the provided expressions will be evaluated only if the previous one returns a value. This means that Bind methods that are invoked after one that returns null (or more correctly: default(T)) will be very cheap.
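This can be made visible by counting how often the supplied function actually runs. The sketch below uses a deliberately small, hypothetical Option<T> type (not either of the Maybe implementations above) and a hypothetical HalveIfEven step. Bind is called five times, but the function is only executed until the first empty result.

```csharp
using System;

public sealed class Option<T>
{
    public static readonly Option<T> None = new Option<T>(default(T), false);

    private readonly T _value;
    public bool HasValue { get; }

    private Option(T value, bool hasValue) { _value = value; HasValue = hasValue; }

    public static Option<T> Some(T value) { return new Option<T>(value, true); }

    // Bind is always invoked, but f only runs when there is a value.
    public Option<R> Bind<R>(Func<T, Option<R>> f)
    {
        return HasValue ? f(_value) : Option<R>.None;
    }
}

public static class BindDemo
{
    public static int BinderCalls;

    private static Option<int> HalveIfEven(int n)
    {
        BinderCalls++;
        return n % 2 == 0 ? Option<int>.Some(n / 2) : Option<int>.None;
    }

    public static Option<int> Run()
    {
        BinderCalls = 0;
        return Option<int>.Some(12)
            .Bind(n => HalveIfEven(n))  // 12 -> 6
            .Bind(n => HalveIfEven(n))  // 6 -> 3
            .Bind(n => HalveIfEven(n))  // 3 is odd -> None
            .Bind(n => HalveIfEven(n))  // function not executed
            .Bind(n => HalveIfEven(n)); // function not executed
    }

    public static void Main()
    {
        var result = Run();
        Console.WriteLine(result.HasValue); // False
        Console.WriteLine(BinderCalls);     // 3
    }
}
```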
We can do this more cunningly.
Write an interface derived from IEnumerable:
public interface IOptional<T>: IEnumerable<T> {}
This keeps compatibility with LINQ methods.
public class Maybe<T>: IOptional<T>
{
private readonly IEnumerable<T> _element;
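public Maybe(T element)
// a null element yields an empty Maybe, so a SelectMany chain stops there
: this(element == null ? new T[0] : new T[1] { element })
{}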
public Maybe()
: this(new T[0])
{}
private Maybe(T[] element)
{
_element = element;
}
public IEnumerator<T> GetEnumerator()
{
return _element.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
After this we can use the full power of LINQ and write this:
var node = new Node("1", new Node("2", new Node("3", new Node("4", null))));
var childNode =
new Maybe<Node>(node.ChildNode)
.SelectMany(n => new Maybe<Node>(n.ChildNode))
.SelectMany(n => new Maybe<Node>(n.ChildNode))
.SelectMany(n => new Maybe<Node>(n.ChildNode))
.SelectMany(n => new Maybe<Node>(n.ChildNode))
.SelectMany(n => new Maybe<Node>(n.ChildNode));
Console.WriteLine(childNode.Any() ? childNode.First().Value : "");
I'm working with reflection and currently have a MethodBody. How do I check if a specific method is called inside the MethodBody?
Assembly assembly = Assembly.Load("Module1");
Type type = assembly.GetType("Module1.ModuleInit");
MethodInfo mi = type.GetMethod("Initialize");
MethodBody mb = mi.GetMethodBody();
Use Mono.Cecil. It is a single standalone assembly that will work on Microsoft .NET as well as Mono. (I think I used version 0.6 or thereabouts back when I wrote the code below)
Say you have a number of assemblies
IEnumerable<AssemblyDefinition> assemblies;
Get these using AssemblyFactory (load one?)
The following snippet would enumerate all usages of methods in all types of these assemblies
methodUsages = assemblies
.SelectMany(assembly => assembly.MainModule.Types.Cast<TypeDefinition>())
.SelectMany(type => type.Methods.Cast<MethodDefinition>())
.Where(method => null != method.Body) // allow abstracts and generics
.SelectMany(method => method.Body.Instructions.Cast<Instruction>())
.Select(instr => instr.Operand)
.OfType<MethodReference>();
This will return all references to methods (so including use in reflection, or to construct expressions which may or may not be executed). As such, this is probably not very useful, except to show you what can be done with the Cecil API without too much of an effort :)
Note that this sample assumes a somewhat older version of Cecil (the one in mainstream mono versions). Newer versions are
more succinct (by using strong typed generic collections)
faster
Of course in your case you could have a single method reference as starting point. Say you want to detect when 'mytargetmethod' can actually be called directly inside 'startingpoint':
MethodReference startingpoint; // get it somewhere using Cecil
MethodReference mytargetmethod; // what you are looking for
bool isCalled = startingpoint
.GetOriginalMethod() // jump to original (for generics e.g.)
.Resolve() // get the definition from the IL image
.Body.Instructions.Cast<Instruction>()
.Any(i => i.OpCode == OpCodes.Callvirt && i.Operand == (mytargetmethod));
Call Tree Search
Here is a working snippet that allows you to recursively search to (selected) methods that call each other (indirectly).
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Mono.Cecil;
using Mono.Cecil.Cil;
namespace StackOverflow
{
/*
* breadth-first lazy search across a subset of the call tree rooting in startingPoint
*
* methodSelect selects the methods to recurse into
* resultGen generates the result objects to be returned by the enumerator
*
*/
class CallTreeSearch<T> : BaseCodeVisitor, IEnumerable<T> where T : class
{
private readonly Func<MethodReference, bool> _methodSelect;
private readonly Func<Instruction, Stack<MethodReference>, T> _transform;
private readonly IEnumerable<MethodDefinition> _startingPoints;
private readonly IDictionary<MethodDefinition, Stack<MethodReference>> _chain = new Dictionary<MethodDefinition, Stack<MethodReference>>();
private readonly ICollection<MethodDefinition> _seen = new HashSet<MethodDefinition>(new CompareMembers<MethodDefinition>());
private readonly ICollection<T> _results = new HashSet<T>();
private Stack<MethodReference> _currentStack;
private const int InfiniteRecursion = -1;
private readonly int _maxrecursiondepth;
private bool _busy;
public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultGen)
: this(startingPoints, methodSelect, resultGen, InfiniteRecursion)
{
}
public CallTreeSearch(IEnumerable<MethodDefinition> startingPoints,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultGen,
int maxrecursiondepth)
{
_startingPoints = startingPoints.ToList();
_methodSelect = methodSelect;
_maxrecursiondepth = maxrecursiondepth;
_transform = resultGen;
}
public override void VisitMethodBody(MethodBody body)
{
_seen.Add(body.Method); // avoid infinite recursion
base.VisitMethodBody(body);
}
public override void VisitInstructionCollection(InstructionCollection instructions)
{
foreach (Instruction instr in instructions)
VisitInstruction(instr);
base.VisitInstructionCollection(instructions);
}
public override void VisitInstruction(Instruction instr)
{
T result = _transform(instr, _currentStack);
if (result != null)
_results.Add(result);
var methodRef = instr.Operand as MethodReference; // TODO select calls only?
if (methodRef != null && _methodSelect(methodRef))
{
var resolve = methodRef.Resolve();
if (null != resolve && !(_chain.ContainsKey(resolve) || _seen.Contains(resolve)))
_chain.Add(resolve, new Stack<MethodReference>(_currentStack.Reverse()));
}
base.VisitInstruction(instr);
}
public IEnumerator<T> GetEnumerator()
{
lock (this) // not multithread safe
{
if (_busy)
throw new InvalidOperationException("CallTreeSearch enumerator is not reentrant");
_busy = true;
try
{
int recursionLevel = 0;
ResetToStartingPoints();
while (_chain.Count > 0 &&
((InfiniteRecursion == _maxrecursiondepth) || recursionLevel++ <= _maxrecursiondepth))
{
// swapout the collection because Visitor will modify
var clone = new Dictionary<MethodDefinition, Stack<MethodReference>>(_chain);
_chain.Clear();
foreach (var call in clone.Where(call => HasBody(call.Key)))
{
// Console.Error.Write("\rCallTreeSearch: level #{0}, scanning {1,-20}\r", recursionLevel, call.Key.Name + new string(' ',21));
_currentStack = call.Value;
_currentStack.Push(call.Key);
try
{
_results.Clear();
call.Key.Body.Accept(this); // grows _chain and _results
}
finally
{
_currentStack.Pop();
}
_currentStack = null;
foreach (var result in _results)
yield return result;
}
}
}
finally
{
_busy = false;
}
}
}
private void ResetToStartingPoints()
{
_chain.Clear();
_seen.Clear();
foreach (var startingPoint in _startingPoints)
{
_chain.Add(startingPoint, new Stack<MethodReference>());
_seen.Add(startingPoint);
}
}
private static bool HasBody(MethodDefinition methodDefinition)
{
return !(methodDefinition.IsAbstract || methodDefinition.Body == null);
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
internal class CompareMembers<T> : IComparer<T>, IEqualityComparer<T>
where T: class, IMemberReference
{
public int Compare(T x, T y)
{ return StringComparer.InvariantCultureIgnoreCase.Compare(KeyFor(x), KeyFor(y)); }
public bool Equals(T x, T y)
{ return KeyFor(x).Equals(KeyFor(y)); }
private static string KeyFor(T mr)
{ return null == mr ? "" : String.Format("{0}::{1}", mr.DeclaringType.FullName, mr.Name); }
public int GetHashCode(T obj)
{ return KeyFor(obj).GetHashCode(); }
}
}
Notes
do some error handling around Resolve() (I have an extension method TryResolve() for the purpose)
optionally select usages of MethodReferences in a call operation (call, calli, callvirt ...) only (see //TODO)
Typical usage:
public static IEnumerable<T> SearchCallTree<T>(this TypeDefinition startingClass,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultFunc,
int maxdepth)
where T : class
{
return new CallTreeSearch<T>(startingClass.Methods.Cast<MethodDefinition>(), methodSelect, resultFunc, maxdepth);
}
public static IEnumerable<T> SearchCallTree<T>(this MethodDefinition startingMethod,
Func<MethodReference, bool> methodSelect,
Func<Instruction, Stack<MethodReference>, T> resultFunc,
int maxdepth)
where T : class
{
return new CallTreeSearch<T>(new[] { startingMethod }, methodSelect, resultFunc, maxdepth);
}
// Actual usage:
private static IEnumerable<TypeUsage> SearchMessages(TypeDefinition uiType, bool onlyConstructions)
{
return uiType.SearchCallTree(IsBusinessCall,
(instruction, stack) => DetectRequestUsage(instruction, stack, onlyConstructions));
}
Note that the completion of a function like DetectRequestUsage to suit your needs is completely and entirely up to you (edit: but see here). You can do whatever you want, and don't forget: you'll have the complete statically analyzed call stack at your disposal, so you actually can do pretty neat things with all that information!
Before it generates code, it must check if it already exists
There are a few cases where catching an exception is way cheaper than preventing it from being generated. This is a prime example. You can get the IL for the method body, but Reflection is not a disassembler. Nor is a disassembler a real fix: you'd have to disassemble the entire call tree to implement your desired behavior. After all, a method called in the body could itself call a method, et cetera. It is just much simpler to catch the exception that the jitter will throw when it compiles the IL.
One can use the StackTrace class:
System.Diagnostics.StackTrace st = new System.Diagnostics.StackTrace();
System.Diagnostics.StackFrame sf = st.GetFrame(1);
Console.Out.Write(sf.GetMethod().ReflectedType.Name + "." + sf.GetMethod().Name);
The 1 can be adjusted; it determines the index of the frame you are interested in.