I could probably write this myself, but the specific way I'm trying to accomplish it is throwing me off. I'm trying to write a generic extension method similar to the others introduced in .NET 3.5 that will take a nested IEnumerable of IEnumerables (and so on) and flatten it into one IEnumerable. Anyone have any ideas?
Specifically, I'm having trouble with the syntax of the extension method itself so that I can work on a flattening algorithm.
Here's an extension that might help. It will traverse all nodes in your hierarchy of objects and pick out the ones that match a criterion. It assumes that each object in your hierarchy has a collection property that holds its child objects.
Here's the extension:
/// <summary>
/// Traverses an object hierarchy and returns a flattened list of elements
/// based on a predicate.
/// </summary>
/// <typeparam name="TSource">The type of object in your collection.</typeparam>
/// <param name="source">The collection of your topmost TSource objects.</param>
/// <param name="selectorFunction">A predicate for choosing the objects you want.</param>
/// <param name="getChildrenFunction">A function that fetches the child collection from an object.</param>
/// <returns>A flattened list of objects which meet the criteria in selectorFunction.</returns>
public static IEnumerable<TSource> Map<TSource>(
this IEnumerable<TSource> source,
Func<TSource, bool> selectorFunction,
Func<TSource, IEnumerable<TSource>> getChildrenFunction)
{
// Start with the top-level elements that match the predicate
var flattenedList = source.Where(selectorFunction);
// Go through the input enumerable looking for children,
// and add those if we have them
foreach (TSource element in source)
{
flattenedList = flattenedList.Concat(
getChildrenFunction(element).Map(selectorFunction,
getChildrenFunction)
);
}
return flattenedList;
}
Examples (Unit Tests):
First we need an object and a nested object hierarchy.
A simple node class
class Node
{
public int NodeId { get; set; }
public int LevelId { get; set; }
public IEnumerable<Node> Children { get; set; }
public override string ToString()
{
return String.Format("Node {0}, Level {1}", this.NodeId, this.LevelId);
}
}
And a method to get a 3-level deep hierarchy of nodes
private IEnumerable<Node> GetNodes()
{
// Create a 3-level deep hierarchy of nodes
Node[] nodes = new Node[]
{
new Node
{
NodeId = 1,
LevelId = 1,
Children = new Node[]
{
new Node { NodeId = 2, LevelId = 2, Children = new Node[] {} },
new Node
{
NodeId = 3,
LevelId = 2,
Children = new Node[]
{
new Node { NodeId = 4, LevelId = 3, Children = new Node[] {} },
new Node { NodeId = 5, LevelId = 3, Children = new Node[] {} }
}
}
}
},
new Node { NodeId = 6, LevelId = 1, Children = new Node[] {} }
};
return nodes;
}
First Test: flatten the hierarchy, no filtering
[Test]
public void Flatten_Nested_Hierarchy()
{
IEnumerable<Node> nodes = GetNodes();
var flattenedNodes = nodes.Map(
p => true,
(Node n) => { return n.Children; }
);
foreach (Node flatNode in flattenedNodes)
{
Console.WriteLine(flatNode.ToString());
}
// Make sure we only end up with 6 nodes
Assert.AreEqual(6, flattenedNodes.Count());
}
This will show:
Node 1, Level 1
Node 6, Level 1
Node 2, Level 2
Node 3, Level 2
Node 4, Level 3
Node 5, Level 3
Second Test: Get a list of nodes that have an even-numbered NodeId
[Test]
public void Only_Return_Nodes_With_Even_Numbered_Node_IDs()
{
IEnumerable<Node> nodes = GetNodes();
var flattenedNodes = nodes.Map(
p => (p.NodeId % 2) == 0,
(Node n) => { return n.Children; }
);
foreach (Node flatNode in flattenedNodes)
{
Console.WriteLine(flatNode.ToString());
}
// Make sure we only end up with 3 nodes
Assert.AreEqual(3, flattenedNodes.Count());
}
This will show:
Node 6, Level 1
Node 2, Level 2
Node 4, Level 3
Hmm... I'm not sure exactly what you want here, but here's a "one level" option:
public static IEnumerable<TElement> Flatten<TElement,TSequence> (this IEnumerable<TSequence> sequences)
where TSequence : IEnumerable<TElement>
{
foreach (TSequence sequence in sequences)
{
foreach(TElement element in sequence)
{
yield return element;
}
}
}
If that's not what you want, could you provide the signature of what you do want? If you don't need a generic form, and you just want to do the kind of thing that LINQ to XML constructors do, that's reasonably simple - although the recursive use of iterator blocks is relatively inefficient. Something like:
static IEnumerable Flatten(params object[] objects)
{
// Can't easily get varargs behaviour with IEnumerable
return Flatten((IEnumerable) objects);
}
static IEnumerable Flatten(IEnumerable enumerable)
{
foreach (object element in enumerable)
{
IEnumerable candidate = element as IEnumerable;
if (candidate != null)
{
foreach (object nested in candidate)
{
yield return nested;
}
}
else
{
yield return element;
}
}
}
Note that that will treat a string as a sequence of chars, however - you may want to special-case strings to be individual elements instead of flattening them, depending on your use case.
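The string caveat above can be handled with an explicit check before the IEnumerable test. The following is only a sketch: DeepFlattenSketch and DeepFlatten are hypothetical names, not part of the answer, and it recurses to any depth rather than one level.

```csharp
using System.Collections;
using System.Collections.Generic;

public static class DeepFlattenSketch
{
    // Hypothetical helper: flattens nested sequences to any depth,
    // but yields strings as single elements instead of as chars.
    public static IEnumerable<object> DeepFlatten(IEnumerable source)
    {
        foreach (object element in source)
        {
            // A string is an IEnumerable of char, so check for it first.
            if (element is string)
            {
                yield return element;
                continue;
            }

            IEnumerable nested = element as IEnumerable;
            if (nested != null)
            {
                // Recurse into the nested sequence.
                foreach (object inner in DeepFlatten(nested))
                    yield return inner;
            }
            else
            {
                yield return element;
            }
        }
    }
}
```

For example, DeepFlattenSketch.DeepFlatten(new object[] { 1, "ab", new object[] { 2, new[] { "cd" } } }) yields 1, "ab", 2, "cd".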
Does that help?
I thought I'd share a complete example with error handling and a single-logic approach.
Recursive flattening is as simple as:
LINQ version
public static class IEnumerableExtensions
{
public static IEnumerable<T> SelectManyRecursive<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
if (source == null) throw new ArgumentNullException("source");
if (selector == null) throw new ArgumentNullException("selector");
return !source.Any() ? source :
source.Concat(
source
.SelectMany(i => selector(i).EmptyIfNull())
.SelectManyRecursive(selector)
);
}
public static IEnumerable<T> EmptyIfNull<T>(this IEnumerable<T> source)
{
return source ?? Enumerable.Empty<T>();
}
}
Non-LINQ version
public static class IEnumerableExtensions
{
public static IEnumerable<T> SelectManyRecursive<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
if (source == null) throw new ArgumentNullException("source");
if (selector == null) throw new ArgumentNullException("selector");
foreach (T item in source)
{
yield return item;
var children = selector(item);
if (children == null)
continue;
foreach (T descendant in children.SelectManyRecursive(selector))
{
yield return descendant;
}
}
}
}
Design decisions
I decided to:
disallow flattening of a null IEnumerable; this can be changed by removing the exception throwing and:
adding source = source.EmptyIfNull(); before return in the 1st version
adding if (source != null) before foreach in the 2nd version
allow the selector to return a null collection - this way the caller is not responsible for ensuring the children list isn't null; this can be changed by:
removing .EmptyIfNull() in the first version - note that SelectMany will fail if the selector returns null
removing if (children == null) continue; in the second version - note that foreach will fail on a null IEnumerable
allow filtering children with a .Where clause on the caller side or within the children selector, rather than passing a separate children filter parameter:
it won't impact efficiency, because in both versions execution is deferred
it would mix another concern into the method, and I prefer to keep the logic separate
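The two filtering options in the last design decision behave differently, which a small sketch may make clearer. Item, FilteringSketch, VisibleItems and VisibleBranches are hypothetical names used only for illustration; SelectManyRecursive is repeated here (without the argument checks) so the sketch compiles on its own.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Item   // hypothetical type for illustration
{
    public string Name { get; set; }
    public bool Hidden { get; set; }
    public List<Item> Children { get; set; }
}

public static class FilteringSketch
{
    // Same shape as the non-LINQ version above.
    public static IEnumerable<T> SelectManyRecursive<T>(
        this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
    {
        foreach (T item in source)
        {
            yield return item;
            var children = selector(item);
            if (children == null) continue;
            foreach (T descendant in children.SelectManyRecursive(selector))
                yield return descendant;
        }
    }

    // Filter on the caller side: hidden items are dropped from the result,
    // but their visible descendants are still returned.
    public static IEnumerable<Item> VisibleItems(IEnumerable<Item> roots)
    {
        return roots.SelectManyRecursive(i => i.Children).Where(i => !i.Hidden);
    }

    // Filter inside the selector: a hidden child prunes its whole subtree.
    // The roots themselves are not filtered here.
    public static IEnumerable<Item> VisibleBranches(IEnumerable<Item> roots)
    {
        return roots.SelectManyRecursive(
            i => i.Children == null ? null : i.Children.Where(c => !c.Hidden));
    }
}
```

Given a visible root A with a hidden child B that has a visible child C, VisibleItems returns A and C, while VisibleBranches returns only A.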
Sample use
I'm using this extension method in LightSwitch to obtain all controls on the screen:
public static class ScreenObjectExtensions
{
public static IEnumerable<IContentItemProxy> FindControls(this IScreenObject screen)
{
var model = screen.Details.GetModel();
return model.GetChildItems()
.SelectManyRecursive(c => c.GetChildItems())
.OfType<IContentItemDefinition>()
.Select(c => screen.FindControl(c.Name));
}
}
Here is a modified version of Jon Skeet's answer that allows more than "one level":
static IEnumerable Flatten(IEnumerable enumerable)
{
foreach (object element in enumerable)
{
IEnumerable candidate = element as IEnumerable;
if (candidate != null)
{
foreach (object nested in Flatten(candidate))
{
yield return nested;
}
}
else
{
yield return element;
}
}
}
disclaimer: I don't know C#.
The same in Python:
#!/usr/bin/env python
def flatten(iterable):
for item in iterable:
if hasattr(item, '__iter__'):
for nested in flatten(item):
yield nested
else:
yield item
if __name__ == '__main__':
for item in flatten([1,[2, 3, [[4], 5]], 6, [[[7]]], [8]]):
print(item, end=" ")
It prints:
1 2 3 4 5 6 7 8
Isn't that what SelectMany is for?
enum1.SelectMany(
a => a.SelectMany(
b => b.SelectMany(
c => c.Select(
d => d.Name
)
)
)
);
Function:
public static class MyExtensions
{
public static IEnumerable<T> RecursiveSelector<T>(this IEnumerable<T> nodes, Func<T, IEnumerable<T>> selector)
{
if(nodes.Any() == false)
{
return nodes;
}
var descendants = nodes
.SelectMany(selector)
.RecursiveSelector(selector);
return nodes.Concat(descendants);
}
}
Usage:
var ar = new[]
{
new Node
{
Name = "1",
Children = new[]
{
new Node
{
Name = "11",
Children = new[]
{
new Node
{
Name = "111",
}
}
}
}
}
};
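// Leaf nodes above leave Children unset, and SelectMany throws if the selector returns null
var flattened = ar.RecursiveSelector(x => x.Children ?? Enumerable.Empty<Node>()).ToList();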
Okay here's another version which is combined from about 3 answers above.
Recursive. Uses yield. Generic. Optional filter predicate. Optional selection function. About as concise as I could make it.
public static IEnumerable<TNode> Flatten<TNode>(
this IEnumerable<TNode> nodes,
Func<TNode, bool> filterBy = null,
Func<TNode, IEnumerable<TNode>> selectChildren = null
)
{
if (nodes == null) yield break;
if (filterBy != null) nodes = nodes.Where(filterBy);
foreach (var node in nodes)
{
yield return node;
var children = (selectChildren == null)
? node as IEnumerable<TNode>
: selectChildren(node);
if (children == null) continue;
foreach (var child in children.Flatten(filterBy, selectChildren))
{
yield return child;
}
}
}
Usage:
// With filter predicate, with selection function
var flatList = nodes.Flatten(n => n.IsDeleted == false, n => n.Children);
The SelectMany extension method does this already.
Projects each element of a sequence to
an IEnumerable<T> and
flattens the resulting sequences into
one sequence.
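As a minimal sketch of what that looks like in practice (SelectManySketch and FlattenOneLevel are hypothetical names for illustration): SelectMany with an identity selector removes exactly one level of nesting, so deeper hierarchies need either one call per level or a recursive helper.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class SelectManySketch
{
    // Flattens exactly one level of nesting using the identity selector.
    public static IEnumerable<T> FlattenOneLevel<T>(IEnumerable<IEnumerable<T>> nested)
    {
        return nested.SelectMany(inner => inner);
    }
}
```

For example, passing new List<List<int>> { new List<int> { 1, 2 }, new List<int> { 3 } } yields 1, 2, 3.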
I had to implement mine from scratch because all of the provided solutions would break in case there is a loop i.e. a child that points to its ancestor. If you have the same requirements as mine please take a look at this (also let me know if my solution would break in any special circumstances):
How to use:
var flattenlist = rootItem.Flatten(obj => obj.ChildItems, obj => obj.Id);
Code:
public static class Extensions
{
/// <summary>
/// This would flatten out a recursive data structure ignoring the loops. The end result would be an enumerable which enumerates all the
/// items in the data structure regardless of the level of nesting.
/// </summary>
/// <typeparam name="T">Type of the recursive data structure</typeparam>
/// <param name="source">Source element</param>
/// <param name="childrenSelector">a function that returns the children of a given data element of type T</param>
/// <param name="keySelector">a function that returns a key value for each element</param>
/// <returns>a flattened list of all the items within the recursive data structure of T</returns>
public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source,
Func<T, IEnumerable<T>> childrenSelector,
Func<T, object> keySelector) where T : class
{
if (source == null)
throw new ArgumentNullException("source");
if (childrenSelector == null)
throw new ArgumentNullException("childrenSelector");
if (keySelector == null)
throw new ArgumentNullException("keySelector");
var stack = new Stack<T>( source);
var dictionary = new Dictionary<object, T>();
while (stack.Any())
{
var currentItem = stack.Pop();
var currentkey = keySelector(currentItem);
if (dictionary.ContainsKey(currentkey) == false)
{
dictionary.Add(currentkey, currentItem);
// Yield only on the first visit, so items reached again through a loop are not repeated
yield return currentItem;
var children = childrenSelector(currentItem);
if (children != null)
{
foreach (var child in children)
{
stack.Push(child);
}
}
}
}
}
/// <summary>
/// This would flatten out a recursive data structure ignoring the loops. The end result would be an enumerable which enumerates all the
/// items in the data structure regardless of the level of nesting.
/// </summary>
/// <typeparam name="T">Type of the recursive data structure</typeparam>
/// <param name="source">Source element</param>
/// <param name="childrenSelector">a function that returns the children of a given data element of type T</param>
/// <param name="keySelector">a function that returns a key value for each element</param>
/// <returns>a flattened list of all the items within the recursive data structure of T</returns>
public static IEnumerable<T> Flatten<T>(this T source,
Func<T, IEnumerable<T>> childrenSelector,
Func<T, object> keySelector) where T: class
{
return Flatten(new [] {source}, childrenSelector, keySelector);
}
}
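To illustrate the loop case, here is a small self-contained sketch. It is a variation rather than the code above: it tracks visited keys in a HashSet and starts from a single root. GraphItem, CycleSafeSketch and FlattenOnce are hypothetical names.

```csharp
using System;
using System.Collections.Generic;

public class GraphItem   // hypothetical type for illustration
{
    public int Id { get; set; }
    public List<GraphItem> ChildItems { get; } = new List<GraphItem>();
}

public static class CycleSafeSketch
{
    // A HashSet of visited keys guarantees each item is yielded
    // at most once, even when a child points back to an ancestor.
    public static IEnumerable<T> FlattenOnce<T, TKey>(
        T root,
        Func<T, IEnumerable<T>> childrenSelector,
        Func<T, TKey> keySelector)
    {
        var visited = new HashSet<TKey>();
        var stack = new Stack<T>();
        stack.Push(root);

        while (stack.Count > 0)
        {
            T current = stack.Pop();

            // Add returns false when the key was already seen.
            if (!visited.Add(keySelector(current)))
                continue;

            yield return current;

            var children = childrenSelector(current);
            if (children == null)
                continue;

            foreach (T child in children)
                stack.Push(child);
        }
    }
}
```

With two items a (Id 1) and b (Id 2), where b is a child of a and a is also a child of b, CycleSafeSketch.FlattenOnce(a, i => i.ChildItems, i => i.Id) terminates and yields a, then b.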
Since yield is not available in older versions of VB (it was added in Visual Basic 11) and LINQ provides both deferred execution and a concise syntax, you can also use:
<Extension()>
Public Function Flatten(Of T)(ByVal objects As Generic.IEnumerable(Of T), ByVal selector As Func(Of T, Generic.IEnumerable(Of T))) As Generic.IEnumerable(Of T)
    If objects.Any() Then
        Return objects.Union(objects.Select(selector).Where(Function(e) e IsNot Nothing).SelectMany(Function(e) e).Flatten(selector))
    Else
        Return objects
    End If
End Function
public static class Extensions{
public static IEnumerable<T> Flatten<T>(this IEnumerable<T> objects, Func<T, IEnumerable<T>> selector) where T:Component{
if(objects.Any()){
return objects.Union(objects.Select(selector).Where(e => e != null).SelectMany(e => e).Flatten(selector));
}
return objects;
}
}
edited to include:
empty enumerable per https://stackoverflow.com/a/30325216/107683,
null enumerable per https://stackoverflow.com/a/39338919/107683
C# implementation.
static class EnumerableExtensions
{
public static IEnumerable<T> Flatten<T>(this IEnumerable<IEnumerable<T>> sequence)
{
foreach(var child in sequence)
foreach(var item in child)
yield return item;
}
}
Maybe like this? Or do you mean that it could potentially be infinitely deep?
class PageViewModel {
public IEnumerable<PageViewModel> ChildrenPages { get; set; }
}
Func<IEnumerable<PageViewModel>, IEnumerable<PageViewModel>> concatAll = null;
concatAll = list => list.SelectMany(l => l.ChildrenPages.Any() ?
concatAll(l.ChildrenPages).Union(new[] { l }) : new[] { l });
var allPages = concatAll(source).ToArray();
Basically, you need to have a master IEnumerable that is outside of your recursive function, then in your recursive function (pseudo-code):
private void flattenList(IEnumerable<T> list)
{
foreach (T item in list)
{
masterList.Add(item);
if (item.Count > 0)
{
this.flattenList(item);
}
}
}
Though I'm really not sure what you mean by an IEnumerable nested in an IEnumerable... what's within that? How many levels of nesting? What's the final type? Obviously my code isn't correct, but I hope it gets you thinking.
Related
I need to search a tree for data that could be anywhere in the tree. How can this be done with linq?
class Program
{
static void Main(string[] args) {
var familyRoot = new Family() {Name = "FamilyRoot"};
var familyB = new Family() {Name = "FamilyB"};
familyRoot.Children.Add(familyB);
var familyC = new Family() {Name = "FamilyC"};
familyB.Children.Add(familyC);
var familyD = new Family() {Name = "FamilyD"};
familyC.Children.Add(familyD);
//There can be from 1 to n levels of families.
//Search all children, grandchildren, great grandchildren etc, for "FamilyD" and return the object.
}
}
public class Family {
public string Name { get; set; }
List<Family> _children = new List<Family>();
public List<Family> Children {
get { return _children; }
}
}
That's an extension to It'sNotALie.'s answer.
public static class Linq
{
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten(c, selector))
.Concat(new[] { source });
}
}
Sample test usage:
var result = familyRoot.Flatten(x => x.Children).FirstOrDefault(x => x.Name == "FamilyD");
Returns familyD object.
You can make it work on IEnumerable<T> source too:
public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
return source.SelectMany(x => Flatten(x, selector))
.Concat(source);
}
Another solution without recursion...
var result = FamilyToEnumerable(familyRoot)
.Where(f => f.Name == "FamilyD");
IEnumerable<Family> FamilyToEnumerable(Family f)
{
Stack<Family> stack = new Stack<Family>();
stack.Push(f);
while (stack.Count > 0)
{
var family = stack.Pop();
yield return family;
foreach (var child in family.Children)
stack.Push(child);
}
}
Simple:
familyRoot.Flatten(f => f.Children);
//you can do whatever you want with that sequence there.
//for example you could use Where on it and find the specific families, etc.
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
    // Recurse on each child itself; selector(c) would be a sequence, not a T
    return selector(source).SelectMany(c => Flatten(c, selector))
                           .Concat(new[]{source});
}
So, the simplest option is to write a function that traverses your hierarchy and produces a single sequence. This then goes at the start of your LINQ operations, e.g.
static IEnumerable<Family> Flatten(this IEnumerable<Family> source)
{
    foreach (var item in source) {
        yield return item;
        foreach (var child in item.Children.Flatten())
            yield return child;
    }
}
To call, simply wrap the root in a sequence: new[] { familyRoot }.Flatten().Where(n => n.Name == "Bob");
A slight alternative would give you a way to quickly ignore a whole branch:
static IEnumerable<Family> Flatten(this IEnumerable<Family> source, Func<Family, bool> predicate)
{
    foreach (var item in source) {
        // Skip the item and its whole branch when the predicate fails
        if (predicate(item)) {
            yield return item;
            foreach (var child in item.Children.Flatten(predicate))
                yield return child;
        }
    }
}
Then you could do things like: new[] { family }.Flatten(n => n.Children.Count > 2).Where(...)
I like Kenneth Bo Christensen's answer using stack, it works great, it is easy to read and it is fast (and doesn't use recursion).
The only unpleasant thing is that it reverses the order of child items (because a stack is LIFO). If the order doesn't matter to you then it's OK.
If it does, the original order can easily be preserved by using selector(current).Reverse() in the foreach loop (the rest of the code is the same as in Kenneth's original post)...
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current).Reverse())
stack.Push(child);
}
}
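The effect of the Reverse() call on ordering can be seen with a small sketch. StackOrderSketch and the keepChildOrder flag are hypothetical additions for illustration; the traversal itself follows the stack-based approach above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class StackOrderSketch
{
    public static IEnumerable<T> Flatten<T>(
        T source, Func<T, IEnumerable<T>> selector, bool keepChildOrder)
    {
        var stack = new Stack<T>();
        stack.Push(source);
        while (stack.Count > 0)
        {
            T current = stack.Pop();
            yield return current;

            IEnumerable<T> children = selector(current);
            // Pushing in reverse means the first child is popped first.
            if (keepChildOrder)
                children = children.Reverse();

            foreach (T child in children)
                stack.Push(child);
        }
    }
}
```

For a tree where root has children A and B, and A has a child A1, keepChildOrder = true yields root, A, A1, B, while keepChildOrder = false yields root, B, A, A1.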
Well, I guess the way is to go with the technique of working with hierarchical structures:
You need an anchor to start from
You need the recursion part
// Anchor
rootFamily.Children.ForEach(childFamily =>
{
if (childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(childFamily);
});
// Recursion
public void SearchForChildren(Family childFamily)
{
childFamily.Children.ForEach(_childFamily =>
{
if (_childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(_childFamily);
});
}
I tried two of the suggested approaches and made the code a bit clearer:
public static IEnumerable<T> Flatten1<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten1(c, selector)).Concat(new[] { source });
}
public static IEnumerable<T> Flatten2<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current))
stack.Push(child);
}
}
Flatten2() seems to be a little bit faster, but it's a close race.
Some further variants on the answers of It'sNotALie., MarcinJuraszek and DamienG.
First, the former two give a counterintuitive ordering. To get a nice tree-traversal ordering to the results, just invert the concatenation (put the "source" first).
Second, if you are working with an expensive source like EF, and you want to limit entire branches, Damien's suggestion that you inject the predicate is a good one and can still be done with Linq.
Finally, for an expensive source it may also be good to pre-select the fields of interest from each node with an injected selector.
Putting all these together:
public static IEnumerable<R> Flatten<T,R>(this T source, Func<T, IEnumerable<T>> children
, Func<T, R> selector
, Func<T, bool> branchpredicate = null
) {
if (children == null) throw new ArgumentNullException("children");
if (selector == null) throw new ArgumentNullException("selector");
var pred = branchpredicate ?? (src => true);
if (children(source) == null) return new[] { selector(source) };
return new[] { selector(source) }
.Concat(children(source)
.Where(pred)
.SelectMany(c => Flatten(c, children, selector, pred)));
}
I'm trying to write an extension method that is supposed to traverse an object graph and return all visited objects.
I'm not sure if my approach is the best, so please do comment on that. Also yield is frying my brain... I'm sure the answer is obvious :/
Model
public class MyClass
{
public MyClass Parent {get;set;}
}
Method
public static IEnumerable<T> SelectNested<T>
(this T source, Func<T, T> selector)
where T : class
{
yield return source;
var parent = selector(source);
if (parent == null)
yield break;
yield return SelectNestedParents(parent, selector).FirstOrDefault();
}
Usage
var list = myObject.SelectNested(x => x.Parent);
The problem
It's almost working, but it only visits 2 objects: itself and the parent.
So given this graph c -> b -> a, starting from c, only c, b is returned, which is not quite what I wanted.
The result I'm looking for is c, b, a
In the last line of SelectNested you only return the first parent:
yield return SelectNestedParents(parent, selector).FirstOrDefault();
You have to return all parents:
foreach (var p in SelectNestedParents(parent, selector))
yield return p;
Instead of using recursion you can use iteration which probably is more efficient:
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
where T : class {
var current = source;
while (current != null) {
yield return current;
current = selector(current);
}
}
The following code should work as expected:
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
    where T : class
{
    if (source != null)
    {
        yield return source;
        var parent = selector(source);
        // The recursive call returns an IEnumerable<T>,
        // so iterate over it and yield each element.
        foreach (var ancestor in SelectNested(parent, selector))
        {
            yield return ancestor;
        }
    }
}
Strictly speaking, your class looks to be a list, not a graph, since selector returns only one object not an enumeration of them. Thus recursion is not necessary.
public static IEnumerable<T> SelectNested<T>(this T source, Func<T, T> selector)
where T : class
{
while (source != null)
{
yield return source;
source = selector(source);
}
}
I find myself regularly writing recursive IEnumerable<T> iterators to implement the same "Descendants" pattern as provided by, for example, XContainer.Descendants. The pattern I keep implementing is as follows, given a type Foo with a single-level iterator called Children:
public static IEnumerable<Foo> Descendants(this Foo root) {
foreach (var child in root.Children()) {
yield return child;
foreach (var subchild in child.Descendants()) {
yield return subchild;
}
}
}
This old StackOverflow question suggests the same pattern. But for some reason it feels weird to me to have to reference three levels of hierarchy (root, child, and subchild). Can this fundamental depth-first recursion pattern be further reduced? Or is this an algorithmic primitive of sorts?
The best I can come up with is to abstract the pattern to a generic extension. This doesn't reduce the logic of the iterator pattern presented above, but it does remove the requirement of defining a Descendants method for multiple specific classes. On the downside, this adds an extension method to Object itself, which is a little smelly:
public static IEnumerable<T> SelectRecurse<T>(
this T root, Func<T, IEnumerable<T>> enumerator) {
foreach (T item in enumerator(root))
{
yield return item;
foreach (T subitem in item.SelectRecurse(enumerator))
{
yield return subitem;
}
}
}
// Now we can just write:
foreach(var item in foo.SelectRecurse(f => f.Children())) { /* do stuff */ }
You can use an explicit stack, rather than implicitly using the thread's call stack, to store the data that you are using. This can even be generalized to a Traverse method that just accepts a delegate to represent the "get my children" call:
public static IEnumerable<T> Traverse<T>(
this IEnumerable<T> source
, Func<T, IEnumerable<T>> childrenSelector)
{
var stack = new Stack<T>(source);
while (stack.Any())
{
var next = stack.Pop();
yield return next;
foreach (var child in childrenSelector(next))
stack.Push(child);
}
}
Because this isn't recursive, and thus isn't creating the state machines constantly, it will perform quite a bit better.
Side note: if you want a Breadth First Search, just use a Queue instead of a Stack. If you want a Best First Search, use a priority queue.
To ensure that siblings are returned in the same order as the selector returns them, rather than in reverse, just add a Reverse call to the result of childrenSelector.
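The breadth-first variant mentioned in the side note can be sketched as follows. BreadthFirstSketch and TraverseBreadthFirst are hypothetical names; the only change from the stack-based Traverse is the collection type.

```csharp
using System;
using System.Collections.Generic;

public static class BreadthFirstSketch
{
    public static IEnumerable<T> TraverseBreadthFirst<T>(
        this IEnumerable<T> source,
        Func<T, IEnumerable<T>> childrenSelector)
    {
        // A queue is first-in, first-out, so all nodes on one level
        // are yielded before any node on the next level.
        var queue = new Queue<T>(source);
        while (queue.Count > 0)
        {
            T next = queue.Dequeue();
            yield return next;
            foreach (T child in childrenSelector(next))
                queue.Enqueue(child);
        }
    }
}
```

For a tree where root has children A and B, and A has a child A1, starting from a sequence containing only root yields root, A, B, A1. Siblings already come out in selector order here, so no Reverse call is needed.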
I think this is a good question. The best explanation I have for why you need two loops: We need to recognize the fact that each item is converted to become multiple items (itself, and all its descendants). This means that we do not map one-to-one (like Select) but one-to-many (SelectMany).
We could write it like this:
public static IEnumerable<Foo> Descendants(this IEnumerable<Foo> items) {
foreach (var item in items) {
yield return item;
foreach (var subitem in item.Children().Descendants())
yield return subitem;
}
}
Or like this:
public static IEnumerable<Foo> Descendants(this Foo root) {
var children = root.Children();
var subchildren = children.SelectMany(c => c.Descendants());
return children.Concat(subchildren);
}
Or like this:
public static IEnumerable<Foo> Descendants(this IEnumerable<Foo> items) {
var children = items.SelectMany(c => c.Children().Descendants());
return items.Concat(children);
}
The versions taking an IEnumerable<Foo> must be invoked on root.Children().
I think all of these rewrites expose a different way of looking at the problem. On the other hand, they all have two nested loops. The loops can be hidden in helper functions but they still exist.
I would manage this with a List:
public static IEnumerable<Foo> Descendants(this Foo root) {
List<Foo> todo = new List<Foo>();
todo.AddRange(root.Children());
while(todo.Count > 0)
{
var first = todo[0];
todo.RemoveAt(0);
todo.InsertRange(0,first.Children());
yield return first;
}
}
Not recursive, so shouldn't blow the stack. You just always add more work for yourself onto the front of the list and so you achieve the depth-first traversal.
Both Damien_the_Unbeliever and Servy have presented versions of an algorithm that avoid creating a recursive call stack by using collections of one type or another. Damien's use of a List could cause poor performance due to inserts at the head of the list, while Servy's use of a stack will cause nested elements to be returned in reverse order. I believe manually implementing a one-way linked list will maintain Servy's performance while still returning all the items in the original order. The only tricky part is initializing the first ForwardLinks by iterating the root. To keep Traverse clean I moved that to a constructor on ForwardLink.
public static IEnumerable<T> Traverse<T>(
this T root,
Func<T, IEnumerable<T>> childrenSelector) {
var head = new ForwardLink<T>(childrenSelector(root));
if (head.Value == null) yield break; // No items from root iterator
while (head != null)
{
var headValue = head.Value;
var localTail = head;
var second = head.Next;
// Insert new elements immediately behind head.
foreach (var child in childrenSelector(headValue))
localTail = localTail.Append(child);
// Splice on the old tail, if there was one
if (second != null) localTail.Next = second;
// Pop the head
yield return headValue;
head = head.Next;
}
}
public class ForwardLink<T> {
public T Value { get; private set; }
public ForwardLink<T> Next { get; set; }
public ForwardLink(T value) { Value = value; }
public ForwardLink(IEnumerable<T> values) {
bool firstElement = true;
ForwardLink<T> tail = null;
foreach (T item in values)
{
if (firstElement)
{
Value = item;
firstElement = false;
tail = this;
}
else
{
tail = tail.Append(item);
}
}
}
public ForwardLink<T> Append(T value) {
return Next = new ForwardLink<T>(value);
}
}
I propose a different version, without using yield:
public abstract class RecursiveEnumerator : IEnumerator {
public RecursiveEnumerator(ICollection collection) {
this.collection = collection;
this.enumerator = collection.GetEnumerator();
}
protected abstract ICollection GetChildCollection(object item);
public bool MoveNext() {
if (enumerator.Current != null) {
ICollection child_collection = GetChildCollection(enumerator.Current);
if (child_collection != null && child_collection.Count > 0) {
stack.Push(enumerator);
enumerator = child_collection.GetEnumerator();
}
}
while (!enumerator.MoveNext()) {
if (stack.Count == 0) return false;
enumerator = stack.Pop();
}
return true;
}
public virtual void Dispose() { }
public object Current { get { return enumerator.Current; } }
public void Reset() {
stack.Clear();
enumerator = collection.GetEnumerator();
}
private IEnumerator enumerator;
private Stack<IEnumerator> stack = new Stack<IEnumerator>();
private ICollection collection;
}
Usage example
public class RecursiveControlEnumerator : RecursiveEnumerator, IEnumerator {
public RecursiveControlEnumerator(Control.ControlCollection controlCollection)
: base(controlCollection) { }
protected override ICollection GetChildCollection(object c) {
return (c as Control).Controls;
}
}
To expand on my comment, this should work:
public static IEnumerable<Foo> Descendants(this Foo node)
{
yield return node; // return branch nodes
foreach (var child in node.Children())
foreach (var c2 in child.Descendants())
yield return c2; // return leaf nodes
}
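That will return all branch nodes and leaf nodes, including the starting node itself. Note that simply removing the first yield return would make the method return nothing, since every node is emitted by that statement; to return only leaf nodes, guard it with a check that the node has no children.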
In response to your question, yes it is an algorithmic primitive, because you definitely need to call node.Children(), and you definitely need to call child.Descendants() on each child. I agree that it seems odd having two "foreach" loops, but the second one is actually just continuing the overall enumeration, not iterating the children.
Try this:
private static IEnumerable<T> Descendants<T>(
this IEnumerable<T> children, Func<T, IEnumerable<T>> enumerator)
{
Func<T, IEnumerable<T>> getDescendants =
child => enumerator(child).Descendants(enumerator);
Func<T, IEnumerable<T>> getChildWithDescendants =
child => new[] { child }.Concat(getDescendants(child));
return children.SelectMany(getChildWithDescendants);
}
Or if you prefer the non Linq variant:
private static IEnumerable<T> Descendants<T>(
this IEnumerable<T> children, Func<T, IEnumerable<T>> enumerator)
{
foreach (var child in children)
{
yield return child;
var descendants = enumerator(child).Descendants(enumerator);
foreach (var descendant in descendants)
{
yield return descendant;
}
}
}
And call it like:
root.Children().Descendants(f => f.Children())
I need to search a tree for data that could be anywhere in the tree. How can this be done with linq?
class Program
{
static void Main(string[] args) {
var familyRoot = new Family() {Name = "FamilyRoot"};
var familyB = new Family() {Name = "FamilyB"};
familyRoot.Children.Add(familyB);
var familyC = new Family() {Name = "FamilyC"};
familyB.Children.Add(familyC);
var familyD = new Family() {Name = "FamilyD"};
familyC.Children.Add(familyD);
//There can be from 1 to n levels of families.
//Search all children, grandchildren, great grandchildren etc, for "FamilyD" and return the object.
}
}
public class Family {
public string Name { get; set; }
List<Family> _children = new List<Family>();
public List<Family> Children {
get { return _children; }
}
}
That's an extension to It'sNotALie.s answer.
public static class Linq
{
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten(c, selector))
.Concat(new[] { source });
}
}
Sample test usage:
var result = familyRoot.Flatten(x => x.Children).FirstOrDefault(x => x.Name == "FamilyD");
Returns familyD object.
You can make it work on IEnumerable<T> source too:
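public static IEnumerable<T> Flatten<T>(this IEnumerable<T> source, Func<T, IEnumerable<T>> selector)
{
// The single-item Flatten already yields each item itself,
// so concatenating source again would duplicate the top-level items.
return source.SelectMany(x => Flatten(x, selector));
}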
Another solution without recursion...
var result = FamilyToEnumerable(familyRoot)
.Where(f => f.Name == "FamilyD");
IEnumerable<Family> FamilyToEnumerable(Family f)
{
Stack<Family> stack = new Stack<Family>();
stack.Push(f);
while (stack.Count > 0)
{
var family = stack.Pop();
yield return family;
foreach (var child in family.Children)
stack.Push(child);
}
}
Simple:
familyRoot.Flatten(f => f.Children);
//you can do whatever you want with that sequence there.
//for example you could use Where on it and find the specific families, etc.
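public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
// Recurse on each child c (not on selector(c), which is a sequence and would not compile here).
return selector(source).SelectMany(c => Flatten(c, selector))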
.Concat(new[]{source});
}
So, the simplest option is to write a function that traverses your hierarchy and produces a single sequence. This then goes at the start of your LINQ operations, e.g.
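// Typed to Family here because the body relies on item.Children.
static IEnumerable<Family> Flatten(this IEnumerable<Family> source)
{
foreach (var item in source) {
yield return item;
foreach (var child in item.Children.Flatten())
yield return child;
}
}
To call it, wrap the root in a sequence: new[] { familyRoot }.Flatten().Where(n => n.Name == "Bob");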
A slight alternative would give you a way to quickly ignore a whole branch:
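// Typed to Family here because the body relies on item.Children.
static IEnumerable<Family> Flatten(this IEnumerable<Family> source, Func<Family, bool> predicate)
{
foreach (var item in source) {
if (predicate(item)) {
yield return item;
// Pass the predicate down so a rejected node hides its whole branch.
foreach (var child in item.Children.Flatten(predicate))
yield return child;
}
}
}
Then you could do things like: new[] { family }.Flatten(n => n.Children.Count > 2).Where(...)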
I like Kenneth Bo Christensen's answer using a stack: it works well, it is easy to read, and it is fast (and doesn't use recursion).
The only unpleasant thing is that it reverses the order of child items (because a stack is LIFO). If the order doesn't matter to you, then it's fine.
If it does, the original order can be restored easily by using selector(current).Reverse() in the foreach loop (the rest of the code is the same as in Kenneth's original post)...
public static IEnumerable<T> Flatten<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current).Reverse())
stack.Push(child);
}
}
Well, I guess the way is to go with the technique of working with hierarchical structures:
You need an anchor to make
You need the recursion part
// Anchor
rootFamily.Children.ForEach(childFamily =>
{
if (childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(childFamily);
});
// Recursion
public void SearchForChildren(Family childFamily)
{
childFamily.Children.ForEach(_childFamily =>
{
if (_childFamily.Name.Contains(search))
{
// Your logic here
return;
}
SearchForChildren(_childFamily);
});
}
I tried two of the suggested approaches and made the code a bit clearer:
public static IEnumerable<T> Flatten1<T>(this T source, Func<T, IEnumerable<T>> selector)
{
return selector(source).SelectMany(c => Flatten1(c, selector)).Concat(new[] { source });
}
public static IEnumerable<T> Flatten2<T>(this T source, Func<T, IEnumerable<T>> selector)
{
var stack = new Stack<T>();
stack.Push(source);
while (stack.Count > 0)
{
var current = stack.Pop();
yield return current;
foreach (var child in selector(current))
stack.Push(child);
}
}
Flatten2() seems to be a little bit faster, but it's a close run.
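If you want to check the timing on your own data, a rough harness along these lines might do. This is only a sketch: `BenchNode`, `BuildTree` and `Time` are names I made up for illustration, and Stopwatch timings like this are noisy, so treat any numbers as indicative at best.

```csharp
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

// Hypothetical node type used only for the measurement.
public class BenchNode
{
    public List<BenchNode> Children { get; } = new List<BenchNode>();
}

public static class FlattenBenchmark
{
    public static IEnumerable<T> Flatten1<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        return selector(source).SelectMany(c => Flatten1(c, selector)).Concat(new[] { source });
    }

    public static IEnumerable<T> Flatten2<T>(this T source, Func<T, IEnumerable<T>> selector)
    {
        var stack = new Stack<T>();
        stack.Push(source);
        while (stack.Count > 0)
        {
            var current = stack.Pop();
            yield return current;
            foreach (var child in selector(current))
                stack.Push(child);
        }
    }

    // Builds a complete tree with the given depth and branching factor.
    public static BenchNode BuildTree(int depth, int branching)
    {
        var node = new BenchNode();
        if (depth > 0)
            for (int i = 0; i < branching; i++)
                node.Children.Add(BuildTree(depth - 1, branching));
        return node;
    }

    // Elapsed time for fully enumerating the sequence `repeat` times.
    public static TimeSpan Time(Func<IEnumerable<BenchNode>> flatten, int repeat)
    {
        var sw = Stopwatch.StartNew();
        for (int i = 0; i < repeat; i++)
            flatten().Count();
        sw.Stop();
        return sw.Elapsed;
    }
}
```

Usage would be something like: var root = FlattenBenchmark.BuildTree(8, 4); then compare FlattenBenchmark.Time(() => root.Flatten1(n => n.Children), 20) against the same call with Flatten2.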
Some further variants on the answers of It'sNotALie., MarcinJuraszek and DamienG.
First, the former two give a counterintuitive ordering. To get a nice tree-traversal ordering to the results, just invert the concatenation (put the "source" first).
Second, if you are working with an expensive source like EF, and you want to limit entire branches, Damien's suggestion that you inject the predicate is a good one and can still be done with Linq.
Finally, for an expensive source it may also be good to pre-select the fields of interest from each node with an injected selector.
Putting all these together:
public static IEnumerable<R> Flatten<T,R>(this T source, Func<T, IEnumerable<T>> children
, Func<T, R> selector
, Func<T, bool> branchpredicate = null
) {
if (children == null) throw new ArgumentNullException("children");
if (selector == null) throw new ArgumentNullException("selector");
var pred = branchpredicate ?? (src => true);
if (children(source) == null) return new[] { selector(source) };
return new[] { selector(source) }
.Concat(children(source)
.Where(pred)
.SelectMany(c => Flatten(c, children, selector, pred)));
}
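A possible usage sketch for the method above, reusing the Family shape from the question. The method is copied in so the snippet compiles on its own; the `FlattenUsage` class and `Run` method are hypothetical names. Note that the branch predicate is applied to children only, so the root is always returned.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public class Family
{
    public string Name { get; set; }
    public List<Family> Children { get; } = new List<Family>();
}

public static class FlattenUsage
{
    // Copy of the method above so the sketch compiles on its own.
    public static IEnumerable<R> Flatten<T, R>(this T source,
        Func<T, IEnumerable<T>> children,
        Func<T, R> selector,
        Func<T, bool> branchpredicate = null)
    {
        if (children == null) throw new ArgumentNullException("children");
        if (selector == null) throw new ArgumentNullException("selector");
        var pred = branchpredicate ?? (src => true);
        if (children(source) == null) return new[] { selector(source) };
        return new[] { selector(source) }
            .Concat(children(source)
                .Where(pred)
                .SelectMany(c => Flatten(c, children, selector, pred)));
    }

    public static List<string> Run()
    {
        var root = new Family { Name = "Root" };
        var b = new Family { Name = "B" };
        var c = new Family { Name = "C" };
        var d = new Family { Name = "D" };
        var e = new Family { Name = "E" };
        root.Children.Add(b);
        root.Children.Add(e);
        b.Children.Add(c);
        c.Children.Add(d);

        // Project each node to its name and skip the whole branch rooted at "C".
        // Result: Root, B, E (C and its child D are never visited).
        return root.Flatten(f => f.Children, f => f.Name, f => f.Name != "C").ToList();
    }
}
```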
Not sure how to call it, but say you have a class that looks like this:
class Person
{
public string Name;
public IEnumerable<Person> Friends;
}
You then have a person and you want to "unroll" this structure recursively so you end up with a single list of all people without duplicates.
How would you do this? I have already made something that seems to be working, but I am curious to see how others would do it and especially if there is something built-in to Linq you can use in a clever way to solve this little problem :)
Here is my solution:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector)
{
// Stop if subjects is null
if(subjects == null)
yield break;
// For each subject
foreach(var subject in subjects)
{
// Yield it
yield return subject;
// Then yield all its descendants
foreach (var descendant in SelectRecursive(selector(subject), selector))
yield return descendant;
}
}
Would be used something like this:
// The extension is defined on IEnumerable<T>, so wrap the single person in an array.
var people = new[] { somePerson }.SelectRecursive(x => x.Friends);
I don't believe there's anything built into LINQ to do this.
There's a problem with doing it recursively like this - you end up creating a large number of iterators. This can be quite inefficient if the tree is deep. Wes Dyer and Eric Lippert have both blogged about this.
You can remove this inefficiency by removing the direct recursion. For example:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects,
Func<T, IEnumerable<T>> selector)
{
if (subjects == null)
{
yield break;
}
Queue<T> stillToProcess = new Queue<T>(subjects);
while (stillToProcess.Count > 0)
{
T item = stillToProcess.Dequeue();
yield return item;
foreach (T child in selector(item))
{
stillToProcess.Enqueue(child);
}
}
}
This will also change the iteration order - it becomes breadth-first instead of depth-first; rewriting it to still be depth-first is tricky. I've also changed it to not use Any() - this revised version won't evaluate any sequence more than once, which can be handy in some scenarios. This does have one problem, mind you - it will take more memory, due to the queuing. We could probably alleviate this by storing a queue of iterators instead of items, but I'm not sure offhand... it would certainly be more complicated.
One point to note (also noted by ChrisW while I was looking up the blog posts :) - if you have any cycles in your friends list (i.e. if A has B, and B has A) then you'll recurse forever.
I found this question as I was looking for and thinking about a similar solution - in my case creating an efficient IEnumerable<Control> for ASP.NET UI controls. The recursive yield I had is fast but I knew that could have extra cost, since the deeper the control structure the longer it could take. Now I know the recursive version costs roughly O(n × depth), which is O(n log n) for a balanced tree.
The solution given here provides some answer but, as discussed in the comments, it does change the order (which the OP did not care about). I realized that to preserve the order as given by the OP and as I needed, neither a simple Queue (as Jon used) nor Stack would work since all the parent objects would be yielded first and then any children after them (or vice-versa).
To resolve this and preserve the order I realized the solution would simply be to put the Enumerator itself on a Stack. Applied to the OP's original code, it would look like this:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector)
{
if (subjects == null)
yield break;
var stack = new Stack<IEnumerator<T>>();
stack.Push(subjects.GetEnumerator());
while (stack.Count > 0)
{
var en = stack.Peek();
if (en.MoveNext())
{
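var subject = en.Current;
yield return subject;
// The selector may return null for leaf nodes; skip those.
var children = selector(subject);
if (children != null)
stack.Push(children.GetEnumerator());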
}
else
{
stack.Pop().Dispose();
}
}
}
I use stack.Peek here to keep from having to push the same enumerator back on to the stack as this is likely to be the more frequent operation, expecting that enumerator to provide more than one item.
This creates the same number of enumerators as in the recursive version but will likely be fewer new objects than putting all the subjects in a queue or stack and continuing to add any descendant subjects. This is O(n) time as each enumerator stands on its own (in the recursive version an implicit call to one MoveNext executes MoveNext on the child enumerators to the current depth in the recursion stack).
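To see the ordering difference discussed above, here is a small side-by-side sketch of the queue-of-items approach and the stack-of-enumerators approach. The `Item` type and the method names are hypothetical. I also added a try/finally, which is my own addition and not part of the code above, so that enumerators still on the stack get disposed if the caller stops early.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical tree node used only for this comparison.
public class Item
{
    public int Id;
    public List<Item> Children = new List<Item>();

    public Item(int id, params Item[] children)
    {
        Id = id;
        Children.AddRange(children);
    }
}

public static class OrderComparison
{
    // Queue of items: breadth-first.
    public static IEnumerable<T> BreadthFirst<T>(this IEnumerable<T> subjects,
        Func<T, IEnumerable<T>> selector)
    {
        var queue = new Queue<T>(subjects);
        while (queue.Count > 0)
        {
            var item = queue.Dequeue();
            yield return item;
            foreach (var child in selector(item))
                queue.Enqueue(child);
        }
    }

    // Stack of enumerators: depth-first, each parent before its children.
    public static IEnumerable<T> DepthFirst<T>(this IEnumerable<T> subjects,
        Func<T, IEnumerable<T>> selector)
    {
        var stack = new Stack<IEnumerator<T>>();
        stack.Push(subjects.GetEnumerator());
        try
        {
            while (stack.Count > 0)
            {
                var en = stack.Peek();
                if (en.MoveNext())
                {
                    yield return en.Current;
                    stack.Push(selector(en.Current).GetEnumerator());
                }
                else
                {
                    stack.Pop().Dispose();
                }
            }
        }
        finally
        {
            // Runs if the caller stops early (e.g. First()); release what is left.
            while (stack.Count > 0)
                stack.Pop().Dispose();
        }
    }

    public static string Ids(IEnumerable<Item> items)
    {
        return string.Join(",", items.Select(i => i.Id));
    }
}
```

For a tree where 1 has children 2 and 3, 2 has children 4 and 5, and 3 has child 6, BreadthFirst gives 1,2,3,4,5,6 while DepthFirst gives 1,2,4,5,3,6.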
You could use a non-recursive method like this as well:
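HashSet<Person> GatherAll (Person p) {
Stack<Person> todo = new Stack<Person> ();
HashSet<Person> results = new HashSet<Person> ();
// Stack<T> has Push, not Add
todo.Push (p); results.Add (p);
while (todo.Count > 0) {
// Renamed so it does not clash with the parameter p
Person current = todo.Pop ();
foreach (Person f in current.Friends)
// HashSet.Add returns false for people already seen, which breaks cycles
if (results.Add (f)) todo.Push (f);
}
return results;
}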
This should handle cycles properly as well. I am starting with a single person, but you could easily expand this to start with a list of persons.
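Starting from a list of persons could look roughly like this. It is a sketch only: the `PersonGatherer` class name is made up, and treating a null Friends collection as "no friends" is my assumption, not something the question specifies.

```csharp
using System.Collections.Generic;

public class Person
{
    public string Name;
    public IEnumerable<Person> Friends;
}

public static class PersonGatherer
{
    // Same algorithm as above, seeded with several starting people.
    public static HashSet<Person> GatherAll(IEnumerable<Person> start)
    {
        var todo = new Stack<Person>();
        var results = new HashSet<Person>();
        foreach (var p in start)
            if (results.Add(p)) todo.Push(p);

        while (todo.Count > 0)
        {
            var current = todo.Pop();
            if (current.Friends == null) continue; // assumption: null means "no friends"
            foreach (var f in current.Friends)
                if (results.Add(f)) todo.Push(f);  // Add returns false for people already seen
        }
        return results;
    }
}
```

Because HashSet<Person> uses reference equality here, two distinct Person objects with the same Name count as different people; pass an IEqualityComparer<Person> to the HashSet constructor if you need something else.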
Here's an implementation that:
Does a depth first recursive select,
Doesn't require double iteration of the child collections,
Doesn't use intermediate collections for the selected elements,
Doesn't handle cycles,
Can do it backwards.
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector)
{
return new RecursiveEnumerable<T>(rootItems, selector, false);
}
public static IEnumerable<T> SelectRecursiveReverse<T>(this IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector)
{
return new RecursiveEnumerable<T>(rootItems, selector, true);
}
class RecursiveEnumerable<T> : IEnumerable<T>
{
public RecursiveEnumerable(IEnumerable<T> rootItems, Func<T, IEnumerable<T>> selector, bool reverse)
{
_rootItems = rootItems;
_selector = selector;
_reverse = reverse;
}
IEnumerable<T> _rootItems;
Func<T, IEnumerable<T>> _selector;
bool _reverse;
public IEnumerator<T> GetEnumerator()
{
return new Enumerator(this);
}
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
class Enumerator : IEnumerator<T>
{
public Enumerator(RecursiveEnumerable<T> owner)
{
_owner = owner;
Reset();
}
RecursiveEnumerable<T> _owner;
T _current;
Stack<IEnumerator<T>> _stack = new Stack<IEnumerator<T>>();
public T Current
{
get
{
if (_stack == null || _stack.Count == 0)
throw new InvalidOperationException();
return _current;
}
}
public void Dispose()
{
_current = default(T);
if (_stack != null)
{
while (_stack.Count > 0)
{
_stack.Pop().Dispose();
}
_stack = null;
}
}
object System.Collections.IEnumerator.Current
{
get { return Current; }
}
public bool MoveNext()
{
if (_owner._reverse)
return MoveReverse();
else
return MoveForward();
}
public bool MoveForward()
{
// First time?
if (_stack == null)
{
// Setup stack
_stack = new Stack<IEnumerator<T>>();
// Start with the root items
_stack.Push(_owner._rootItems.GetEnumerator());
}
// Process enumerators on the stack
while (_stack.Count > 0)
{
// Get the current one
var se = _stack.Peek();
// Next please...
if (se.MoveNext())
{
// Store it
_current = se.Current;
// Get child items
var childItems = _owner._selector(_current);
if (childItems != null)
{
_stack.Push(childItems.GetEnumerator());
}
return true;
}
// Finished with the enumerator
se.Dispose();
_stack.Pop();
}
// Finished!
return false;
}
public bool MoveReverse()
{
// First time?
if (_stack == null)
{
// Setup stack
_stack = new Stack<IEnumerator<T>>();
// Start with the root items
_stack.Push(_owner._rootItems.Reverse().GetEnumerator());
}
// Process enumerators on the stack
while (_stack.Count > 0)
{
// Get the current one
var se = _stack.Peek();
// Next please...
if (se.MoveNext())
{
// Get child items
var childItems = _owner._selector(se.Current);
if (childItems != null)
{
_stack.Push(childItems.Reverse().GetEnumerator());
continue;
}
// Store it
_current = se.Current;
return true;
}
// Finished with the enumerator
se.Dispose();
_stack.Pop();
if (_stack.Count > 0)
{
_current = _stack.Peek().Current;
return true;
}
}
// Finished!
return false;
}
public void Reset()
{
Dispose();
}
}
}
use the Aggregate extension...
List<Person> persons = GetPersons();
List<Person> result = new List<Person>();
persons.Aggregate(result,SomeFunc);
private static List<Person> SomeFunc(List<Person> arg1,Person arg2)
{
arg1.Add(arg2);
// Note: this adds only the direct children (one level), not deeper descendants.
arg1.AddRange(arg2.Persons);
return arg1;
}
Recursion is always fun. Perhaps you could simplify your code to:
public static IEnumerable<T> SelectRecursive<T>(this IEnumerable<T> subjects, Func<T, IEnumerable<T>> selector) {
// Stop if subjects are null or empty
if (subjects == null || !subjects.Any())
return Enumerable.Empty<T>();
// Gather a list of all (selected) child elements of all subjects
var subjectChildren = subjects.SelectMany(selector);
// Jump into the recursion for each of the child elements
var recursiveChildren = SelectRecursive(subjectChildren, selector);
// Combine the subjects with all of their (recursive child elements).
// The union will remove any direct parent-child duplicates.
// Endless loops due to circular references are however still possible.
return subjects.Union(recursiveChildren);
}
It will result in fewer duplicates than your original code. However, there might still be duplicates causing an endless loop; the union will only prevent direct parent-child duplicates.
And the order of the items will be different from yours :)
Edit: Changed the last line of code to three statements and added a bit more documentation.
While it's great to have IEnumerable when there might be a lot of data, it's worth remembering the classic approach of recursively adding to a list.
That can be as simple as this (I've left out selector; just demonstrating recursively adding to an output list):
class Node
{
public readonly List<Node> Children = new List<Node>();
public List<Node> Flatten()
{
var all = new List<Node>();
Flatten(all);
return all;
}
public void Flatten(List<Node> all)
{
all.Add(this);
foreach (var child in Children)
child.Flatten(all);
}
}
usage:
Node rootNode = ...;
...
var all = rootNode.Flatten();
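If you do want the selector back, a generic version of the same "recursively add to a list" idea might look like the sketch below. `TreeNode`, `FlattenToList` and `AddRecursive` are hypothetical names, and skipping a null child collection is my assumption.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical node type for the demo.
public class TreeNode
{
    public string Label;
    public List<TreeNode> Children = new List<TreeNode>();

    public TreeNode(string label, params TreeNode[] children)
    {
        Label = label;
        Children.AddRange(children);
    }
}

public static class ListFlattenExtensions
{
    // Eagerly collects the root and all descendants, parent before children.
    public static List<T> FlattenToList<T>(this T root, Func<T, IEnumerable<T>> selector)
    {
        var all = new List<T>();
        AddRecursive(root, selector, all);
        return all;
    }

    private static void AddRecursive<T>(T item, Func<T, IEnumerable<T>> selector, List<T> all)
    {
        all.Add(item);
        var children = selector(item);
        if (children == null) return; // assumption: null means "no children"
        foreach (var child in children)
            AddRecursive(child, selector, all);
    }
}
```

Usage would be something like: var all = rootNode.FlattenToList(n => n.Children);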