I need to search a tree for data that could be anywhere in the tree. How can this be done with LINQ?
class Program
{
static void Main(string[] args) {
var familyRoot = new Family() {Name = "FamilyRoot"};
var familyB = new Family() {Name = "FamilyB"};
familyRoot.Children.Add(familyB);
var familyC = new Family() {Name = "FamilyC"};
familyB.Children.Add(familyC);
var familyD = new Family() {Name = "FamilyD"};
familyC.Children.Add(familyD);
//There can be from 1 to n levels of families.
//Search all children, grandchildren, great grandchildren etc, for "FamilyD" and return the object.
}
}
public class Family {
public string Name { get; set; }
List<Family> _children = new List<Family>();
public List<Family> Children {
get { return _children; }
}
}
That's an extension to It'sNotALie.s answer.
public static class Linq
{
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten(c, selector))
.Concat(new[] { source });
}
}
Sample test usage:
var result = familyRoot.Flatten(x => x.Children).FirstOrDefault(x => x.Name == "FamilyD");
Returns familyD object.
You can make it work on IEnumerable<T> source too:
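public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
    // The single-item Flatten above already yields each item itself,
    // so concatenating source again would return the root items twice.
    return source.SelectMany(x => Flatten(x, selector));
}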
Another solution without recursion...
var result = FamilyToEnumerable(familyRoot)
.Where(f => f.Name == "FamilyD");
IEnumerable<Family> FamilyToEnumerable(Family f)
{
Stack<Family> stack = new Stack<Family>();
stack.Push(f);
while (stack.Count > 0)
{
var family = stack.Pop();
yield return family;
foreach (var child in family.Children)
stack.Push(child);
}
}
Simple:
familyRoot.Flatten(f => f.Children);
//you can do whatever you want with that sequence there.
//for example you could use Where on it and find the specific families, etc.
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
    // Recurse on each child c itself; selector(c) is a sequence, not a single T.
    return selector(source).SelectMany(c => Flatten(c, selector))
        .Concat(new[] { source });
}
So, the simplest option is to write a function that traverses your hierarchy and produces a single sequence. This then goes at the start of your LINQ operations, e.g.
public static IEnumerable<Family> Flatten(this Family source)
{
    yield return source;
    foreach (var child in source.Children)
        foreach (var descendant in child.Flatten())
            yield return descendant;
}
To call simply: familyRoot.Flatten().Where(n => n.Name == "Bob");
A slight alternative would give you a way to quickly ignore a whole branch:
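public static IEnumerable<Family> Flatten(this Family source, Func<Family, bool> predicate)
{
    // If the predicate rejects this node, skip it and its whole branch.
    if (!predicate(source))
        yield break;
    yield return source;
    foreach (var child in source.Children)
        foreach (var descendant in child.Flatten(predicate))
            yield return descendant;
}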
Then you could do things like: family.Flatten(n => n.Children.Count > 2).Where(...)
I like Kenneth Bo Christensen's answer using a stack: it works great, it is easy to read, and it is fast (and doesn't use recursion).
The only unpleasant thing is that it reverses the order of child items (because a stack is LIFO). If the order doesn't matter to you, that's fine.
If it does, the original order can be preserved easily by using selector(current).Reverse() in the foreach loop (the rest of the code is the same as in Kenneth's original post)...
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current).Reverse())
stack.Push(child);
}
}
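To make the ordering difference concrete, here is a minimal sketch comparing the two stack-based variants on a tiny tree. The Node type, the method names and the sample tree are made up for illustration; only the traversal loops mirror the code above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical node type, used only for this illustration.
public class Node
{
    public string Name { get; set; }
    public List<Node> Children { get; } = new List<Node>();
}

public static class StackOrderSketch
{
    // Children are pushed as-is, so the last child is popped (and yielded) first.
    public static IEnumerable<T> FlattenAsPushed<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        var stack = new Stack<T>();
        stack.Push(source);
        while (stack.Count > 0)
        {
            var current = stack.Pop();
            yield return current;
            foreach (var child in selector(current))
                stack.Push(child);
        }
    }

    // Children are pushed in reverse, so the first child is popped first.
    public static IEnumerable<T> FlattenInOrder<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        var stack = new Stack<T>();
        stack.Push(source);
        while (stack.Count > 0)
        {
            var current = stack.Pop();
            yield return current;
            foreach (var child in selector(current).Reverse())
                stack.Push(child);
        }
    }

    // Builds the sample tree: root has children A and B, and A has child A1.
    public static Node BuildSample()
    {
        var root = new Node { Name = "root" };
        var a = new Node { Name = "A" };
        var b = new Node { Name = "B" };
        a.Children.Add(new Node { Name = "A1" });
        root.Children.Add(a);
        root.Children.Add(b);
        return root;
    }
}
```

With this sample, root.FlattenAsPushed(n => n.Children) visits root, B, A, A1, while root.FlattenInOrder(n => n.Children) visits root, A, A1, B.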
Well, I guess the way to go is the usual technique for working with hierarchical structures:
You need an anchor to start from
You need the recursion part
// Anchor
rootFamily.Children.ForEach(childFamily =>
{
if (childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(childFamily);
});
// Recursion
public void SearchForChildren(Family childFamily)
{
childFamily.Children.ForEach(_childFamily =>
{
if (_childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(_childFamily);
});
}
I have tried two of the suggested approaches and made the code a bit clearer:
public static IEnumerable<T> Flatten1<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten1(c, selector)).Concat(new[] { source });
}
public static IEnumerable<T> Flatten2<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current))
stack.Push(child);
}
}
Flatten2() seems to be a little bit faster, but it's a close run.
Some further variants on the answers of It'sNotALie., MarcinJuraszek and DamienG.
First, the former two give a counterintuitive ordering. To get a nice tree-traversal ordering to the results, just invert the concatenation (put the "source" first).
Second, if you are working with an expensive source like EF, and you want to limit entire branches, Damien's suggestion that you inject the predicate is a good one and can still be done with Linq.
Finally, for an expensive source it may also be good to pre-select the fields of interest from each node with an injected selector.
Putting all these together:
public static IEnumerable<R> Flatten<T,R>(this T source, Func<T, IEnumerable<T>> children
, Func<T, R> selector
, Func<T, bool> branchpredicate = null
) {
if (children == null) throw new ArgumentNullException("children");
if (selector == null) throw new ArgumentNullException("selector");
var pred = branchpredicate ?? (src => true);
if (children(source) == null) return new[] { selector(source) };
return new[] { selector(source) }
.Concat(children(source)
.Where(pred)
.SelectMany(c => Flatten(c, children, selector, pred)));
}
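As a small illustration of the first point (putting the source first), here is a sketch that places the two concatenation orders side by side. The method names FlattenPreOrder and FlattenPostOrder are made up for this example, and Family is a copy of the class from the question.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Same shape as the Family class in the question.
public class Family
{
    public string Name { get; set; }
    public List<Family> Children { get; } = new List<Family>();
}

public static class ConcatOrderSketch
{
    // Source first, then its descendants: parents come before children.
    public static IEnumerable<T> FlattenPreOrder<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        return new[] { source }
            .Concat(selector(source).SelectMany(c => c.FlattenPreOrder(selector)));
    }

    // Descendants first, then the source: children come before parents.
    public static IEnumerable<T> FlattenPostOrder<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        return selector(source).SelectMany(c => c.FlattenPostOrder(selector))
            .Concat(new[] { source });
    }
}
```

For the chain FamilyRoot -> FamilyB -> FamilyC -> FamilyD from the question, FlattenPreOrder yields FamilyRoot, FamilyB, FamilyC, FamilyD, and FlattenPostOrder yields the same items in the opposite order.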
I'm trying to write an extension method that is supposed to traverse an object graph and return all visited objects.
I'm not sure if my approach is the best, so please do comment on that. Also yield is frying my brain... I'm sure the answer is obvious :/
Model
public class MyClass
{
public MyClass Parent {get;set;}
}
Method
public static IEnumerable<T> SelectNested<T>
(this T source, Func<T, T> selector)
where T : class
{
yield return source;
var parent = selector(source);
if (parent == null)
yield break;
yield return SelectNestedParents(parent, selector).FirstOrDefault();
}
Usage
var list = myObject.SelectNested(x => x.Parent);
The problem
It's almost working. But it only visits 2 objects: itself and the parent.
So given this graph c -> b -> a, starting from c, only c and b are returned, which is not quite what I wanted.
The result I'm looking for is c, b, a
In the last line of SelectNested you only return the first parent:
yield return SelectNestedParents(parent, selector).FirstOrDefault();
You have to return all parents:
foreach (var p in SelectNestedParents(parent, selector))
yield return p;
Instead of using recursion you can use iteration which probably is more efficient:
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
where T : class {
var current = source;
while (current != null) {
yield return current;
current = selector(current);
}
}
The following code should work as expected:
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
    where T : class
{
    if (source != null)
    {
        yield return source;
        // The result of the recursive call is an IEnumerable<T>,
        // so you need to iterate over it and yield each element.
        foreach (var ancestor in SelectNested(selector(source), selector))
        {
            yield return ancestor;
        }
    }
}
Strictly speaking, your class looks to be a list, not a graph, since selector returns only one object not an enumeration of them. Thus recursion is not necessary.
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
where T : class
{
while (source != null)
{
yield return source;
source = selector(source);
}
}
Say I have a method like this:
IEnumerable<record> GetSomeRecords()
{
while(...)
{
yield return aRecord;
}
}
Now, let's say I have a caller that also returns an enumerable of the same type, something like this:
IEnumerable<record> ParentGetSomeRecords()
{
// I want to do this, but for some reason, my brain locks right here
foreach(var item in someItems)
yield return GetSomeRecords();
}
That code gives a compile error because yield return wants a record, and I'm returning an IEnumerable of records.
I want one "flat" IEnumerable that flattens a nested loop of enumerables. It's making me crazy, because I know I've done this before, but I can't seem to remember how. Got any hints?
Is this what you are after?
IEnumerable<record> ParentGetSomeRecords()
{
foreach(var item in someItems)
foreach(var record in GetSomeRecords())
yield return record;
}
As noted, this only works for a single level of children, but it is the closest equivalent to your example code.
Update
Some people seem to believe you want the ability to flatten a hierarchical structure. Here is an extension method which performs breadth-first flattening (get the siblings before children):
Coming from a single item:
[Pure]
public static IEnumerable<T> BreadthFirstFlatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
Contract.Requires(!ReferenceEquals(source, null));
Contract.Requires(selector != null);
Contract.Ensures(Contract.Result<IEnumerable<T>>() != null);
var pendingChildren = new List<T> {source};
while (pendingChildren.Any())
{
var localPending = pendingChildren.ToList();
pendingChildren.Clear();
foreach (var child in localPending)
{
yield return child;
var results = selector(child);
if (results != null)
pendingChildren.AddRange(results);
}
}
}
This can be used like so:
record rec = ...;
IEnumerable<record> flattened = rec.BreadthFirstFlatten(r => r.ChildRecords);
This will result in an IEnumerable<record> containing rec, all of rec's children, all of the children's children, and so on.
If you are coming from a collection of records, use the following code:
[Pure]
private static IEnumerable<T> BreadthFirstFlatten<T, TResult>(IEnumerable<T> source, Func<T, TResult> selector, Action<List<T>, TResult> addMethod)
{
Contract.Requires(source != null);
Contract.Requires(selector != null);
Contract.Requires(addMethod != null);
Contract.Ensures(Contract.Result<IEnumerable<T>>() != null);
var pendingChildren = new List<T>(source);
while (pendingChildren.Any())
{
var localPending = pendingChildren.ToList();
pendingChildren.Clear();
foreach (var child in localPending)
{
yield return child;
var results = selector(child);
if (!ReferenceEquals(results, null))
addMethod(pendingChildren, results);
}
}
}
[Pure]
public static IEnumerable<T> BreadthFirstFlatten<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
Contract.Requires(source != null);
Contract.Requires(selector != null);
Contract.Ensures(Contract.Result<IEnumerable<T>>() != null);
return BreadthFirstFlatten(source, selector, (collection, arg2) => collection.AddRange(arg2));
}
[Pure]
public static IEnumerable<T> BreadthFirstFlatten<T>(this IEnumerable<T> source, Func<T, T> selector)
{
Contract.Requires(source != null);
Contract.Requires(selector != null);
Contract.Ensures(Contract.Result<IEnumerable<T>>() != null);
return BreadthFirstFlatten(source, selector, (collection, arg2) => collection.Add(arg2));
}
These two extension methods can be used like so:
IEnumerable<record> records = ...;
IEnumerable<record> flattened = records.BreadthFirstFlatten(r => r.ChildRecords);
Or from the reverse direction:
IEnumerable<record> records = ...;
IEnumerable<record> flattened = records.BreadthFirstFlatten(r => r.ParentRecords);
All of these extension methods are iterative so not limited by the stack size.
I have a whole host of these types of methods, including pre-order and post-order depth-first traversal, if you wish to see them, I will make a repo and upload them somewhere :)
How about:
IEnumerable<record> ParentGetSomeRecords()
{
var nestedEnumerable = <whatever the heck gets your nested set>;
// SelectMany with an identity flattens
// IEnumerable<IEnumerable<T>> to just IEnumerable<T>
return nestedEnumerable.SelectMany(rec => rec);
}
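Here is a concrete sketch of the same SelectMany idea, applied directly to the items rather than to a pre-built nested set. The record type is replaced by int and GetSomeRecords is given a made-up body, purely so the example is self-contained.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class SelectManySketch
{
    // Hypothetical stand-in for GetSomeRecords: two "records" per item.
    public static IEnumerable<int> GetSomeRecords(int item)
    {
        for (int i = 1; i <= 2; i++)
            yield return item * 10 + i;
    }

    // SelectMany calls GetSomeRecords once per item and
    // concatenates the resulting sequences into one flat sequence.
    public static IEnumerable<int> ParentGetSomeRecords(IEnumerable<int> someItems)
    {
        return someItems.SelectMany(item => GetSomeRecords(item));
    }
}
```

Under these assumptions, SelectManySketch.ParentGetSomeRecords(new[] { 1, 2 }) produces 11, 12, 21, 22, and like the nested foreach version it is evaluated lazily.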
Inefficient, but you could use this:
List<Record> rcrdList = new List<Record>();
foreach (var item in someItems)
{
rcrdList.AddRange(item.GetSomeRecords());
}
return rcrdList;
I have a problem with a List of objects ...
This List contains objects which themselves contain objects, and so on ... (all objects are of the same type)
My objects look like this:
public class MyObject (...)
{
...
public MyObject[] Object;
...
}
I'd like to change some variables of these objects (according to certain parameters), and I'm thinking of using LINQ to do that.
My problem is that I don't really know how to walk through ALL of my recursive list, regardless of nesting level.
I hope I was as clear as possible.
Thank you in advance for your help.
You can write a simple recursive method to do what you want easily enough:
public static void Touch(MyObject obj, string otherParameter)
{
obj.Value = otherParameter;
foreach (var child in obj.Object)
{
Touch(child, otherParameter);
}
}
If you really, really want a more LINQ-esque method, or you do this often enough to need a more generic approach, you could use something like this:
public static IEnumerable<T> FlattenTree<T>(IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
//you could change this to a Queue or any other data structure
//to change the type of traversal from depth first to breadth first or whatever
//seed the stack with the root items, otherwise the loop below never runs
var stack = new Stack<T>(source);
while (stack.Any())
{
T next = stack.Pop();
yield return next;
foreach (T child in selector(next))
stack.Push(child);
}
}
You could then use it like:
MyObject root = new MyObject();
var allNodes = FlattenTree(new[] { root }, node => node.Object);
foreach (var node in allNodes)
{
node.Value = "value";
}
You could use this recursive extension method:
public static IEnumerable<T> Traverse<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> fnRecurse)
{
foreach (T item in source)
{
yield return item;
IEnumerable<T> seqRecurse = fnRecurse(item);
if (seqRecurse != null)
{
foreach (T itemRecurse in Traverse(seqRecurse, fnRecurse))
{
yield return itemRecurse;
}
}
}
}
You can use it in this way:
var allObj = list.Traverse(o => o.Object);
foreach (MyObject o in allObj)
{
// do something
}
It's handy because it's generic and works with any type and also because it's using deferred execution.
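To illustrate the deferred execution point, here is a small sketch that counts how often the child selector runs. The MyObject stand-in, the class name and the CountSelectorCalls helper are assumptions made for this example; the Traverse body is the same as above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for the MyObject type from the question.
public class MyObject
{
    public string Value;
    public MyObject[] Object = new MyObject[0];
}

public static class DeferredTraverseSketch
{
    // Same shape as the Traverse method above.
    public static IEnumerable<T> Traverse<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> fnRecurse)
    {
        foreach (T item in source)
        {
            yield return item;
            IEnumerable<T> seqRecurse = fnRecurse(item);
            if (seqRecurse != null)
            {
                foreach (T itemRecurse in Traverse(seqRecurse, fnRecurse))
                    yield return itemRecurse;
            }
        }
    }

    // Reports how often the child selector has run at three points in time.
    public static int[] CountSelectorCalls()
    {
        int calls = 0;
        var root = new MyObject { Object = new[] { new MyObject(), new MyObject() } };

        var query = new[] { root }.Traverse(o => { calls++; return o.Object; });
        int afterBuildingQuery = calls;   // 0: nothing has been enumerated yet

        query.First();
        int afterFirst = calls;           // 0: the root is yielded before its children are requested

        query.ToList();
        int afterToList = calls;          // 3: one selector call per node (root + 2 children)

        return new[] { afterBuildingQuery, afterFirst, afterToList };
    }
}
```

Note that each enumeration of the query walks the tree again, so call ToList() once if you need to iterate the result several times.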
Maybe simply something like this:
static void AddRecursively(MyObject obj, List<MyObject> listToAddTo)
{
listToAddTo.Add(obj);
foreach (var o in obj.Object)
AddRecursively(o, listToAddTo);
}
I suggest using this extension method, which applies an action to all the items recursively:
public static void ForEach<T>(this IEnumerable<T> source,
Func<T, IEnumerable<T>> getChildren,
Action<T> action)
{
if (source == null) {
return;
}
foreach (T item in source) {
action(item);
IEnumerable<T> children = getChildren(item);
children.ForEach(getChildren, action);
}
}
You would apply it to your list like this
myObjectList.ForEach(x => x.Object, x => x.Value = "new value");
The first parameter tells ForEach how to access the nested objects. The second parameter tells it what to do with each item.
Not sure how to call it, but say you have a class that looks like this:
class Person
{
public string Name;
public IEnumerable<Person> Friends;
}
You then have a person and you want to "unroll" this structure recursively so you end up with a single list of all people without duplicates.
How would you do this? I have already made something that seems to be working, but I am curious to see how others would do it and especially if there is something built-in to Linq you can use in a clever way to solve this little problem :)
Here is my solution:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector)
{
// Stop if subjects are null or empty
if(subjects == null)
yield break;
// For each subject
foreach(var subject in subjects)
{
// Yield it
yield return subject;
// Then yield all its descendants
foreach (var descendant in SelectRecursive(selector(subject), selector))
yield return descendant;
}
}
Would be used something like this:
var people = new[] { somePerson }.SelectRecursive(x => x.Friends);
I don't believe there's anything built into LINQ to do this.
There's a problem with doing it recursively like this - you end up creating a large number of iterators. This can be quite inefficient if the tree is deep. Wes Dyer and Eric Lippert have both blogged about this.
You can remove this inefficiency by removing the direct recursion. For example:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects,
Func<T, IEnumerable<T>> selector)
{
if (subjects == null)
{
yield break;
}
Queue<T> stillToProcess = new Queue<T>(subjects);
while (stillToProcess.Count > 0)
{
T item = stillToProcess.Dequeue();
yield return item;
foreach (T child in selector(item))
{
stillToProcess.Enqueue(child);
}
}
}
This will also change the iteration order - it becomes breadth-first instead of depth-first; rewriting it to still be depth-first is tricky. I've also changed it to not use Any() - this revised version won't evaluate any sequence more than once, which can be handy in some scenarios. This does have one problem, mind you - it will take more memory, due to the queuing. We could probably alleviate this by storing a queue of iterators instead of items, but I'm not sure offhand... it would certainly be more complicated.
One point to note (also noted by ChrisW while I was looking up the blog posts :) - if you have any cycles in your friends list (i.e. if A has B, and B has A) then you'll recurse forever.
I found this question as I was looking for and thinking about a similar solution - in my case creating an efficient IEnumerable<Control> for ASP.NET UI controls. The recursive yield I had is fast but I knew that could have extra cost, since the deeper the control structure the longer it could take. Now I know this is O(n log n).
The solution given here provides some answer but, as discussed in the comments, it does change the order (which the OP did not care about). I realized that to preserve the order as given by the OP and as I needed, neither a simple Queue (as Jon used) nor Stack would work since all the parent objects would be yielded first and then any children after them (or vice-versa).
To resolve this and preserve the order I realized the solution would simply be to put the Enumerator itself on a Stack. Using the OP's original question, it would look like this:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector)
{
if (subjects == null)
yield break;
var stack = new Stack<IEnumerator<T>>();
stack.Push(subjects.GetEnumerator());
while (stack.Count > 0)
{
var en = stack.Peek();
if (en.MoveNext())
{
var subject = en.Current;
yield return subject;
stack.Push(selector(subject).GetEnumerator());
}
else
{
stack.Pop().Dispose();
}
}
}
I use stack.Peek here to keep from having to push the same enumerator back on to the stack as this is likely to be the more frequent operation, expecting that enumerator to provide more than one item.
This creates the same number of enumerators as in the recursive version but will likely be fewer new objects than putting all the subjects in a queue or stack and continuing to add any descendant subjects. This is O(n) time as each enumerator stands on its own (in the recursive version an implicit call to one MoveNext executes MoveNext on the child enumerators to the current depth in the recursion stack).
You could use a non-recursive method like this as well:
HashSet<Person> GatherAll (Person p) {
    Stack<Person> todo = new Stack<Person> ();
    HashSet<Person> results = new HashSet<Person> ();
    todo.Push (p); results.Add (p);
    while (todo.Count > 0) {
        Person current = todo.Pop ();
        foreach (Person f in current.Friends)
            // Add returns false for people already seen, which breaks cycles
            if (results.Add (f)) todo.Push (f);
    }
    return results;
}
This should handle cycles properly as well. I am starting with a single person, but you could easily expand this to start with a list of persons.
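For the list-of-persons case mentioned above, a minimal sketch could look like the following. It is generic and lazy, which is a design choice of this example rather than part of the original answer. The Person class here uses a List<Person> for Friends (an assumption, so the sample graph is easy to build), and equality is the default HashSet<T> comparison, i.e. reference equality for classes.

```csharp
using System;
using System.Collections.Generic;

// Simplified stand-in for the Person class from the question.
public class Person
{
    public string Name;
    public List<Person> Friends = new List<Person>();
}

public static class GatherSketch
{
    // Yields every item reachable from any of the roots exactly once,
    // even when the graph contains cycles.
    public static IEnumerable<T> GatherAll<T>(IEnumerable<T> roots, Func<T, IEnumerable<T>> selector)
    {
        var seen = new HashSet<T>();
        var todo = new Stack<T>();
        foreach (var root in roots)
            if (seen.Add(root)) todo.Push(root);

        while (todo.Count > 0)
        {
            var current = todo.Pop();
            yield return current;

            var next = selector(current);
            if (next == null) continue;   // treat a null child sequence as "no children"
            foreach (var item in next)
                if (seen.Add(item)) todo.Push(item);   // Add returns false for items already seen
        }
    }
}
```

Usage would be something like GatherSketch.GatherAll(people, p => p.Friends). The cost is the extra memory for the HashSet of visited items.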
Here's an implementation that:
Does a depth first recursive select,
Doesn't require double iteration of the child collections,
Doesn't use intermediate collections for the selected elements,
Doesn't handle cycles,
Can do it backwards.
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector)
{
return new RecursiveEnumerable<T>(rootItems, selector, false);
}
public static IEnumerable<T> SelectRecursiveReverse<T>(this IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector)
{
return new RecursiveEnumerable<T>(rootItems, selector, true);
}
class RecursiveEnumerable<T> : IEnumerable<T>
{
public RecursiveEnumerable(IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector, bool reverse)
{
_rootItems = rootItems;
_selector = selector;
_reverse = reverse;
}
IEnumerable<T> _rootItems;
Func<T, IEnumerable<T>> _selector;
bool _reverse;
public IEnumerator<T> GetEnumerator()
{
return new Enumerator(this);
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
class Enumerator : IEnumerator<T>
{
public Enumerator(RecursiveEnumerable<T> owner)
{
_owner = owner;
Reset();
}
RecursiveEnumerable<T> _owner;
T _current;
Stack<IEnumerator<T>> _stack = new Stack<IEnumerator<T>>();
public T Current
{
get
{
if (_stack == null || _stack.Count == 0)
throw new InvalidOperationException();
return _current;
}
}
public void Dispose()
{
_current = default(T);
if (_stack != null)
{
while (_stack.Count > 0)
{
_stack.Pop().Dispose();
}
_stack = null;
}
}
object System.Collections.IEnumerator.Current
{
get { return Current; }
}
public bool MoveNext()
{
if (_owner._reverse)
return MoveReverse();
else
return MoveForward();
}
public bool MoveForward()
{
// First time?
if (_stack == null)
{
// Setup stack
_stack = new Stack<IEnumerator<T>>();
// Start with the root items
_stack.Push(_owner._rootItems.GetEnumerator());
}
// Process enumerators on the stack
while (_stack.Count > 0)
{
// Get the current one
var se = _stack.Peek();
// Next please...
if (se.MoveNext())
{
// Store it
_current = se.Current;
// Get child items
var childItems = _owner._selector(_current);
if (childItems != null)
{
_stack.Push(childItems.GetEnumerator());
}
return true;
}
// Finished with the enumerator
se.Dispose();
_stack.Pop();
}
// Finished!
return false;
}
public bool MoveReverse()
{
// First time?
if (_stack == null)
{
// Setup stack
_stack = new Stack<IEnumerator<T>>();
// Start with the root items
_stack.Push(_owner._rootItems.Reverse().GetEnumerator());
}
// Process enumerators on the stack
while (_stack.Count > 0)
{
// Get the current one
var se = _stack.Peek();
// Next please...
if (se.MoveNext())
{
// Get child items
var childItems = _owner._selector(se.Current);
if (childItems != null)
{
_stack.Push(childItems.Reverse().GetEnumerator());
continue;
}
// Store it
_current = se.Current;
return true;
}
// Finished with the enumerator
se.Dispose();
_stack.Pop();
if (_stack.Count > 0)
{
_current = _stack.Peek().Current;
return true;
}
}
// Finished!
return false;
}
public void Reset()
{
Dispose();
}
}
}
Use the Aggregate extension...
List<Person> persons = GetPersons();
List<Person> result = new List<Person>();
persons.Aggregate(result, SomeFunc);
// Note: this flattens only one level (each person plus their direct Persons).
private static List<Person> SomeFunc(List<Person> arg1, Person arg2)
{
    arg1.Add(arg2);
    arg1.AddRange(arg2.Persons);
    return arg1;
}
Recursion is always fun. Perhaps you could simplify your code to:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector) {
// Stop if subjects are null or empty
if (subjects == null || !subjects.Any())
return Enumerable.Empty<T>();
// Gather a list of all (selected) child elements of all subjects
var subjectChildren = subjects.SelectMany(selector);
// Jump into the recursion for each of the child elements
var recursiveChildren = SelectRecursive(subjectChildren, selector);
// Combine the subjects with all of their (recursive child elements).
// The union will remove any direct parent-child duplicates.
// Endless loops due to circular references are however still possible.
return subjects.Union(recursiveChildren);
}
It will result in fewer duplicates than your original code. However, there might still be duplicates causing an endless loop; the union only prevents direct parent-child duplicates.
And the order of the items will be different from yours :)
Edit: Changed the last line of code to three statements and added a bit more documentation.
While it's great to have IEnumerable when there might be a lot of data, it's worth remembering the classic approach of recursively adding to a list.
That can be as simple as this (I've left out selector; just demonstrating recursively adding to an output list):
class Node
{
public readonly List<Node> Children = new List<Node>();
public List<Node> Flatten()
{
var all = new List<Node>();
Flatten(all);
return all;
}
public void Flatten(List<Node> all)
{
all.Add(this);
foreach (var child in Children)
child.Flatten(all);
}
}
usage:
Node rootNode = ...;
...
var all = rootNode.Flatten();