How to implement Task.WhenAny() with a predicate - c#

I want to execute several asynchronous tasks concurrently. Each task will run an HTTP request that can either complete successfully or throw an exception. I need to await until the first task completes successfully, or until all the tasks have failed.
How can I implement an overload of the Task.WhenAny method that accepts a predicate, so that I can exclude the non-successfully completed tasks?

Wait for any task and return it if the condition is met. Otherwise, wait again for the remaining tasks until there are no more tasks to wait for.
public static async Task<Task> WhenAny(IEnumerable<Task> tasks, Predicate<Task> condition)
{
    var tasklist = tasks.ToList();
    while (tasklist.Count > 0)
    {
        // Wait for the next task to complete (successfully, faulted or canceled)
        var task = await Task.WhenAny(tasklist);
        if (condition(task))
            return task;
        // Not a match: stop watching it and wait for the others
        tasklist.Remove(task);
    }
    // No task satisfied the condition
    return null;
}
A simple check:
var tasks = new List<Task>
{
    Task.FromException(new Exception()),
    Task.FromException(new Exception()),
    Task.FromException(new Exception()),
    Task.CompletedTask,
};
var completedTask = WhenAny(tasks, t => t.Status == TaskStatus.RanToCompletion).Result;
if (tasks.IndexOf(completedTask) != 3)
    throw new Exception("not expected");

public static Task<T> GetFirstResult<T>(
    ICollection<Func<CancellationToken, Task<T>>> taskFactories,
    Predicate<T> predicate) where T : class
{
    var tcs = new TaskCompletionSource<T>();
    var cts = new CancellationTokenSource();
    int completedCount = 0;
    // in case you have a lot of tasks you might need to throttle them
    // (e.g. so you don't try to send 99999999 requests at the same time)
    // see: http://stackoverflow.com/a/25877042/67824
    foreach (var taskFactory in taskFactories)
    {
        taskFactory(cts.Token).ContinueWith(t =>
        {
            // Only read t.Result for tasks that ran to completion:
            // a canceled task has no Exception, but t.Result would still throw
            if (t.Status != TaskStatus.RanToCompletion)
            {
                Console.WriteLine($"Task completed with status {t.Status}: {t.Exception}");
            }
            else if (predicate(t.Result))
            {
                cts.Cancel(); // cancel the remaining requests
                tcs.TrySetResult(t.Result);
            }
            // TrySetException, because the result may already have been set above
            if (Interlocked.Increment(ref completedCount) == taskFactories.Count)
            {
                tcs.TrySetException(new InvalidOperationException("All tasks failed"));
            }
        }, cts.Token);
    }
    return tcs.Task;
}
Sample usage:
using System.Net.Http;
var client = new HttpClient();
var response = await GetFirstResult(
    new Func<CancellationToken, Task<HttpResponseMessage>>[]
    {
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
        ct => client.GetAsync("http://microsoft123456.com", ct),
    },
    rm => rm.IsSuccessStatusCode);
Console.WriteLine($"Successful response: {response}");

public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks, Func<Task<T>, bool> predicate)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    if (predicate == null) throw new ArgumentNullException(nameof(predicate));
    var tasksArray = (tasks as IReadOnlyList<Task<T>>) ?? tasks.ToArray();
    if (tasksArray.Count == 0) throw new ArgumentException("Empty task list", nameof(tasks));
    if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));
    var tcs = new TaskCompletionSource<Task<T>>();
    var count = tasksArray.Count;
    Action<Task<T>> continuation = t =>
    {
        if (predicate(t))
        {
            tcs.TrySetResult(t);
        }
        if (Interlocked.Decrement(ref count) == 0)
        {
            tcs.TrySetResult(null);
        }
    };
    foreach (var task in tasksArray)
    {
        task.ContinueWith(continuation);
    }
    return tcs.Task;
}
Sample usage:
var task = await WhenFirst(tasks, t => t.Status == TaskStatus.RanToCompletion);
if (task != null)
{
    var value = await task;
}
Note that this doesn't propagate exceptions of failed tasks (just as WhenAny doesn't).
You can also create a version of this for the non-generic Task.
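Following the remark above, a non-generic version could look like this. It is a minimal sketch that mirrors WhenFirst<T> step by step; the container class TaskCombinators is a hypothetical name, the continuations are scheduled explicitly on TaskScheduler.Default, and, like the generic version, it does not propagate exceptions thrown by the predicate.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class TaskCombinators
{
    // Non-generic counterpart of WhenFirst<T>: completes with the first task that
    // satisfies the predicate, or with null once all tasks completed without a match.
    public static Task<Task> WhenFirst(IEnumerable<Task> tasks, Func<Task, bool> predicate)
    {
        if (tasks == null) throw new ArgumentNullException(nameof(tasks));
        if (predicate == null) throw new ArgumentNullException(nameof(predicate));
        var tasksArray = tasks.ToArray();
        if (tasksArray.Length == 0) throw new ArgumentException("Empty task list", nameof(tasks));
        if (tasksArray.Any(t => t == null)) throw new ArgumentException("Tasks contains a null reference", nameof(tasks));
        var tcs = new TaskCompletionSource<Task>(TaskCreationOptions.RunContinuationsAsynchronously);
        var count = tasksArray.Length;
        foreach (var task in tasksArray)
        {
            task.ContinueWith(t =>
            {
                // TrySetResult ignores every call after the first successful one
                if (predicate(t)) tcs.TrySetResult(t);
                // The last task to finish reports "no match" (a no-op if a match was already set)
                if (Interlocked.Decrement(ref count) == 0) tcs.TrySetResult(null);
            }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Default);
        }
        return tcs.Task;
    }
}
```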

Here is an attempted improvement of Eli Arbel's excellent answer. These are the improved points:
An exception in the predicate is propagated as a fault of the returned task.
The predicate is not called after a task has been accepted as the result.
The predicate is executed in the original SynchronizationContext. This makes it possible to access UI elements (if the WhenFirst method is called from a UI thread).
The source IEnumerable<Task<T>> is enumerated directly, without being converted to an array first.
public static Task<Task<T>> WhenFirst<T>(IEnumerable<Task<T>> tasks,
    Func<Task<T>, bool> predicate)
{
    if (tasks == null) throw new ArgumentNullException(nameof(tasks));
    if (predicate == null) throw new ArgumentNullException(nameof(predicate));
    var tcs = new TaskCompletionSource<Task<T>>(
        TaskCreationOptions.RunContinuationsAsynchronously);
    var pendingCount = 1; // The initial 1 represents the enumeration itself
    foreach (var task in tasks)
    {
        if (task == null) throw new ArgumentException($"The {nameof(tasks)}" +
            " argument included a null value.", nameof(tasks));
        Interlocked.Increment(ref pendingCount);
        HandleTaskCompletion(task);
    }
    if (Interlocked.Decrement(ref pendingCount) == 0) tcs.TrySetResult(null);
    return tcs.Task;

    async void HandleTaskCompletion(Task<T> task)
    {
        try
        {
            await task; // Continue on the captured context
        }
        catch { } // Ignore exception
        if (tcs.Task.IsCompleted) return;
        try
        {
            if (predicate(task))
                tcs.TrySetResult(task);
            else
                if (Interlocked.Decrement(ref pendingCount) == 0)
                    tcs.TrySetResult(null);
        }
        catch (Exception ex)
        {
            tcs.TrySetException(ex);
        }
    }
}

Another way of doing this, very similar to Sir Rufo's answer, but using AsyncEnumerable and Ix.NET.
Implement a little helper method that streams each task as soon as it completes:
static IAsyncEnumerable<Task<T>> WhenCompleted<T>(IEnumerable<Task<T>> source) =>
    AsyncEnumerable.Create(_ =>
    {
        var tasks = source.ToList();
        Task<T> current = null;
        return AsyncEnumerator.Create(
            async () => tasks.Any() && tasks.Remove(current = await Task.WhenAny(tasks)),
            () => current,
            async () => { });
    });
One can then process the tasks in completion order, e.g. returning the first matching one as requested:
await WhenCompleted(tasks).FirstOrDefault(t => t.Status == TaskStatus.RanToCompletion);
(Depending on the package version, this operator may be named FirstOrDefaultAsync instead.)
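On C# 8 and later, where IAsyncEnumerable<T> is built in, the same helper can be sketched as an async iterator without Ix.NET. This is a minimal sketch, not the answer's code: WhenEachCompleted and FirstMatchingAsync are hypothetical names, and a HashSet<T> is used so that removing a finished task is cheap.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class CompletionOrder
{
    // Yields each task as soon as it completes (successfully, faulted or canceled).
    public static async IAsyncEnumerable<Task<T>> WhenEachCompleted<T>(IEnumerable<Task<T>> source)
    {
        var pending = new HashSet<Task<T>>(source);
        while (pending.Count > 0)
        {
            Task<T> finished = await Task.WhenAny(pending);
            pending.Remove(finished);
            yield return finished;
        }
    }

    // Returns the first task, in completion order, that satisfies the predicate, or null.
    public static async Task<Task<T>> FirstMatchingAsync<T>(
        IEnumerable<Task<T>> source, Func<Task<T>, bool> predicate)
    {
        await foreach (var task in WhenEachCompleted(source))
        {
            if (predicate(task)) return task;
        }
        return null;
    }
}
```

Because the iterator is lazy, FirstMatchingAsync stops waiting as soon as it finds a match; the remaining tasks keep running but are no longer observed.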

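Just wanted to add to the answers by @Peebo and @SirRufo that use List.Remove (I can't comment yet).
I would consider using:
var tasks = source.ToHashSet();
instead of:
var tasks = source.ToList();
so that removing is more efficient (O(1) for a HashSet<T> instead of O(n) for a List<T>). Note that Task.WhenAny itself still copies the collection on every call, so each iteration remains O(n) overall.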

Related

How to use Task.Factory.FromAsync with ldap library

I found this class online:
public class AsyncSearcher
{
    LdapConnection _connect;
    public AsyncSearcher(LdapConnection connection)
    {
        this._connect = connection;
        this._connect.AutoBind = true; //will bind on first search
    }
    public void BeginPagedSearch(
        string baseDN,
        string filter,
        string[] attribs,
        int pageSize,
        Action<SearchResponse> page,
        Action<Exception> completed
        )
    {
        if (page == null)
            throw new ArgumentNullException("page");
        AsyncOperation asyncOp = AsyncOperationManager.CreateOperation(null);
        Action<Exception> done = e =>
        {
            if (completed != null) asyncOp.Post(delegate
            {
                completed(e);
            }, null);
        };
        SearchRequest request = new SearchRequest(
            baseDN,
            filter,
            System.DirectoryServices.Protocols.SearchScope.Subtree,
            attribs
            );
        PageResultRequestControl prc = new PageResultRequestControl(pageSize);
        //add the paging control
        request.Controls.Add(prc);
        AsyncCallback rc = null;
        rc = readResult =>
        {
            try
            {
                var response = (SearchResponse)_connect.EndSendRequest(readResult);
                //let current thread handle results
                asyncOp.Post(delegate
                {
                    page(response);
                }, null);
                var cookie = response.Controls
                    .Where(c => c is PageResultResponseControl)
                    .Select(s => ((PageResultResponseControl)s).Cookie)
                    .Single();
                if (cookie != null && cookie.Length != 0)
                {
                    prc.Cookie = cookie;
                    _connect.BeginSendRequest(
                        request,
                        PartialResultProcessing.NoPartialResultSupport,
                        rc,
                        null
                        );
                }
                else done(null); //signal complete
            }
            catch (Exception ex) { done(ex); }
        };
        //kick off async
        try
        {
            _connect.BeginSendRequest(
                request,
                PartialResultProcessing.NoPartialResultSupport,
                rc,
                null
                );
        }
        catch (Exception ex) { done(ex); }
    }
}
I am basically trying to convert the code below, which writes to the console, so that it returns the data through Task.Factory.FromAsync and I can use the data elsewhere.
using (LdapConnection connection = CreateConnection(servername))
{
    AsyncSearcher searcher = new AsyncSearcher(connection);
    searcher.BeginPagedSearch(
        baseDN,
        "(sn=Dunn)",
        null,
        100,
        f => //runs per page
        {
            foreach (var item in f.Entries)
            {
                var entry = item as SearchResultEntry;
                if (entry != null)
                {
                    Console.WriteLine(entry.DistinguishedName);
                }
            }
        },
        c => //runs on error or when done
        {
            if (c != null) Console.WriteLine(c.ToString());
            Console.WriteLine("Done");
            _resetEvent.Set();
        }
    );
    _resetEvent.WaitOne();
}
I tried this, but I get syntax errors:
LdapConnection connection1 = CreateConnection(servername);
AsyncSearcher1 searcher = new AsyncSearcher1(connection1);
async Task<SearchResultEntryCollection> RootDSE(LdapConnection connection)
{
    return await Task.Factory.FromAsync(,
        () =>
        {
            return searcher.BeginPagedSearch(baseDN, "(cn=a*)", null, 100, f => { return f.Entries; }, c => { _resetEvent.Set(); });
        }
    );
}
_resetEvent.WaitOne();
The APM ("Asynchronous Programming Model") style of asynchronous code uses Begin and End method pairs along with IAsyncResult, following a specific pattern.
The Task.Factory.FromAsync method is designed to wrap APM method pairs into the modern TAP ("Task-based Asynchronous Pattern") style of asynchronous code.
However, FromAsync requires the methods to follow the APM pattern exactly, and BeginPagedSearch does not follow it, so you will need to use TaskCompletionSource<T> directly. TaskCompletionSource<T> can convert any existing asynchronous pattern to TAP, as long as that pattern produces a single result.
The method you're trying to wrap reports its results through multiple callbacks (one per page), so it can't be mapped directly to a single-result TAP method. If you want to collect all the result pages and return them as a list, you can use TaskCompletionSource<T> for that. Otherwise, you'll want to use something like IAsyncEnumerable<T>, which would require writing your own implementation of BeginPagedSearch.
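A minimal sketch of the "collect all pages" approach with TaskCompletionSource<T>. The helper name CollectPagesAsync is hypothetical, and it is written against any start method with the same callback shape as BeginPagedSearch (one callback per page, one for completion or error), so it does not depend on System.DirectoryServices.Protocols. Since AsyncSearcher posts its callbacks through AsyncOperation, they may run on thread-pool threads when there is no SynchronizationContext (e.g. in a console app), hence the lock; in that situation the last page callback may also run after the completion callback, which this sketch does not guard against.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class CallbackToTask
{
    // Adapts a callback-based paged operation into a single Task.
    // start receives two callbacks: onPage (called once per page) and
    // onCompleted (called once, with null on success or the error on failure).
    public static Task<List<TPage>> CollectPagesAsync<TPage>(
        Action<Action<TPage>, Action<Exception>> start)
    {
        var pages = new List<TPage>();
        var tcs = new TaskCompletionSource<List<TPage>>(
            TaskCreationOptions.RunContinuationsAsynchronously);
        try
        {
            start(
                page => { lock (pages) { pages.Add(page); } },
                error =>
                {
                    if (error != null)
                    {
                        tcs.TrySetException(error);
                    }
                    else
                    {
                        lock (pages) { tcs.TrySetResult(new List<TPage>(pages)); }
                    }
                });
        }
        catch (Exception ex)
        {
            // start itself failed synchronously
            tcs.TrySetException(ex);
        }
        return tcs.Task;
    }
}

// Possible usage with the question's AsyncSearcher (not compiled here):
// List<SearchResponse> pages = await CallbackToTask.CollectPagesAsync<SearchResponse>(
//     (onPage, onDone) => searcher.BeginPagedSearch(baseDN, "(sn=Dunn)", null, 100, onPage, onDone));
```

With this in place, the ManualResetEvent is no longer needed: awaiting the returned task replaces _resetEvent.WaitOne().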

Error handling for Tasks inside Task.WhenAll

I'm trying to create a web scraper that queries a lot of URLs in parallel and waits for their responses using Task.WhenAll(). However, if one of the tasks is unsuccessful, WhenAll fails. I expect many of the tasks to return a 404 and want to handle or ignore those. For example:
var urls = Enumerable.Range(1, 1000).Select(i => "https://somewebsite.com/" + i);
List<Task<string>> tasks = new List<Task<string>>();
foreach (string url in urls)
{
    tasks.Add(Task.Run(() =>
    {
        try
        {
            return (new HttpClient()).GetStringAsync(url);
        }
        catch (HttpRequestException)
        {
            return Task.FromResult<string>("");
        }
    }));
}
var responseStrings = await Task.WhenAll(tasks);
This never hits the catch statement, and WhenAll fails at the first 404. How can I get WhenAll to ignore exceptions and just return the tasks that completed successfully? Better yet, could it be done somewhere in the code below?
var tasks = Enumerable.Range(1, 1000).Select(i => (new HttpClient()).GetStringAsync("https://somewebsite.com/" + i));
var responseStrings = await Task.WhenAll(tasks);
Thanks for your help.
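The catch never runs because an HTTP error such as a 404 is not thrown synchronously by GetStringAsync: the HttpRequestException is stored in the returned task and only rethrown where that task is awaited. So you need to use await inside the try block to observe the exception:
var tasks = Enumerable.Range(1, 1000).Select(i => TryGetStringAsync("https://somewebsite.com/" + i));
var responseStrings = await Task.WhenAll(tasks);
var validResponses = responseStrings.Where(x => x != null);
// A single shared HttpClient instance
private static readonly HttpClient httpClient = new HttpClient();
private async Task<string> TryGetStringAsync(string url)
{
    try
    {
        return await httpClient.GetStringAsync(url);
    }
    catch (HttpRequestException)
    {
        // e.g. a 404: treat it as "no response"
        return null;
    }
}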
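To see the difference without hitting the network, here is a small self-contained sketch. FakeGetStringAsync is a hypothetical stand-in for GetStringAsync that fails for odd ids, the way a 404 would; TryGetWithoutAwait reproduces the question's pattern, and TryGetAsync applies the fix above.

```csharp
using System;
using System.Linq;
using System.Threading.Tasks;

public static class ObserveFaults
{
    // Hypothetical stand-in for GetStringAsync: odd ids fail like a 404 would.
    static async Task<string> FakeGetStringAsync(int id)
    {
        await Task.Yield();
        if (id % 2 == 1) throw new InvalidOperationException($"404 for {id}");
        return $"page {id}";
    }

    // The question's pattern: no await, so the catch never runs and the faulted task escapes.
    static Task<string> TryGetWithoutAwait(int id)
    {
        try { return FakeGetStringAsync(id); }
        catch (InvalidOperationException) { return Task.FromResult<string>(null); }
    }

    // The fix: awaiting inside the try lets the catch observe the fault.
    static async Task<string> TryGetAsync(int id)
    {
        try { return await FakeGetStringAsync(id); }
        catch (InvalidOperationException) { return null; }
    }

    public static async Task<bool> WhenAllThrowsWithoutAwaitAsync(int count)
    {
        try
        {
            await Task.WhenAll(Enumerable.Range(1, count).Select(TryGetWithoutAwait));
            return false;
        }
        catch (InvalidOperationException)
        {
            return true; // the "404" escaped the per-request try/catch
        }
    }

    public static async Task<string[]> GetSuccessfulAsync(int count)
    {
        string[] all = await Task.WhenAll(Enumerable.Range(1, count).Select(TryGetAsync));
        return all.Where(s => s != null).ToArray();
    }
}
```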

How to avoid calling onComplete before on next is finished?

Let's say I have a data service class that fetches the data batch by batch and pushes the chunks to its subscribers.
public class DataService
{
    public IObservable<IList<T>> QuerySegmentedObservable<T>(string tableName) where T : TableEntity, new()
    {
        return Observable.Create<IList<T>>(async (observer, token) =>
        {
            TableContinuationToken continuationToken = null;
            do
            {
                var currentSegment = CallData(); // fetch the next segment (details omitted)
                observer.OnNext(currentSegment.Results);
                continuationToken = currentSegment.ContinuationToken;
            } while (continuationToken != null);
            observer.OnCompleted();
        });
    }
}
I subscribe to this observable as follows:
public async Task<bool> MyMethod()
{
    var tcs = new TaskCompletionSource<bool>();
    var observable = _dataService.QuerySegmentedObservable<TSource>(_sourceTableName);
    var dataCount = 0;
    _databaseService.OpenConnection();
    observable.Subscribe(async data =>
        {
            await _databaseService.DoSomething(data);
            dataCount += data.Count;
            Console.WriteLine($"Processing - {dataCount}");
        },
        err =>
        {
            Console.WriteLine($"Error - {err.Message}");
            tcs.SetResult(false);
        },
        () =>
        {
            _databaseService.CloseConnection();
            Console.WriteLine($"Finished");
            tcs.SetResult(true);
        }
    );
    return await tcs.Task;
}
The problem is that OnCompleted() is called before the last OnNext() handler has finished, so I end up closing the connection before the work I'm doing in Subscribe() is done.
Is there any way to fix it? Thanks.
Rx does support async/await within operators, but you're using it within a subscription: the async lambda passed to Subscribe becomes an async void handler, so Rx doesn't wait for it before calling OnCompleted. So (hopefully) you can move the async work into an operator such as SelectMany and change your code to something like this:
public async Task<bool> MyMethod()
{
    var tcs = new TaskCompletionSource<bool>();
    _databaseService.OpenConnection();
    var dataCount = 0;
    _dataService.QuerySegmentedObservable<TSource>(_sourceTableName)
        .SelectMany(async data =>
        {
            await _databaseService.DoSomething(data);
            return data;
        })
        //.Finally(() => _databaseService.CloseConnection()) //This would be called on OnComplete and OnError, just like try-finally
        .Subscribe(data =>
            {
                dataCount += data.Count;
                Console.WriteLine($"Processing - {dataCount}");
            },
            err =>
            {
                Console.WriteLine($"Error - {err.Message}");
                tcs.SetResult(false);
            },
            () =>
            {
                _databaseService.CloseConnection(); //Maybe move this to a Finally call?
                Console.WriteLine($"Finished");
                tcs.SetResult(true);
            }
        );
    return await tcs.Task;
}
I can't really test it, so I hope that sets you on the right path. If you need more help, then please post a better MCVE.
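One caveat: SelectMany with an async selector starts the work for a new batch as soon as that batch arrives, so several DoSomething calls can run concurrently and finish out of order. If the batches must be processed one at a time, a common alternative is Select + Observable.FromAsync + Concat. A minimal sketch, assuming the System.Reactive NuGet package; ProcessAllAsync is a hypothetical helper that returns the total item count once the last batch has been handled.

```csharp
using System;
using System.Collections.Generic;
using System.Reactive.Linq;
using System.Threading.Tasks;

public static class SequentialProcessing
{
    // Runs handleAsync for each batch, one batch at a time, and completes
    // only after the handler for the last batch has finished.
    public static async Task<int> ProcessAllAsync<T>(
        IObservable<IList<T>> batches,
        Func<IList<T>, Task> handleAsync)
    {
        return await batches
            // FromAsync is deferred: the handler starts only when Concat subscribes
            .Select(batch => Observable.FromAsync(async () =>
            {
                await handleAsync(batch);
                return batch.Count;
            }))
            // Concat subscribes to the next inner observable only after the previous one completes
            .Concat()
            // Emits a single total when the whole sequence completes
            .Aggregate(0, (total, count) => total + count);
    }
}
```

In MyMethod, awaiting ProcessAllAsync(..., data => _databaseService.DoSomething(data)) inside a try/finally would then let CloseConnection run only after every batch is done, without a TaskCompletionSource.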

Why do I have to use (async method).result instead of await (async method)?

I am starting 2 channels in the Azure Media Services portal.
Starting a channel takes a long time, about 25-30 seconds per channel. Hence, multithreading :)
However, the following is not clear to me:
I have 2 methods:
public async Task<bool> StartAsync(string programName, CancellationToken token = default(CancellationToken))
{
    var workerThreads = new List<Thread>();
    var results = new List<bool>();
    foreach (var azureProgram in _accounts.GetPrograms(programName))
    {
        var thread = new Thread(() =>
        {
            var result = StartChannelAsync(azureProgram).Result;
            lock (results)
            {
                results.Add(result);
            }
        });
        workerThreads.Add(thread);
        thread.Start();
    }
    foreach (var thread in workerThreads)
    {
        thread.Join();
    }
    return results.All(r => r);
}
and
private async Task<bool> StartChannelAsync(IProgram azureProgram)
{
    var state = _channelFactory.ConvertToState(azureProgram.Channel.State);
    if (state == State.Running)
    {
        return true;
    }
    if (state.IsTransitioning())
    {
        return false;
    }
    await azureProgram.Channel.StartAsync();
    return true;
}
in the first method I use
var result = StartChannelAsync(azureProgram).Result;
In this case everything works fine. But if I use
var result = await StartChannelAsync(azureProgram);
execution is not awaited, and results ends up with zero entries.
What am I missing here?
And is this the correct approach?
Any comments on the code is appreciated. I am not a multithreading king ;)
Cheers!
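With await, the lambda you pass to new Thread becomes an async void method: the thread finishes as soon as it reaches the first incomplete await, so Join() returns before any result has been added. .Result works only because it blocks the thread until the task completes. Don't spawn new Thread instances to execute tasks in parallel; instead, use Task.WhenAll:
public async Task<bool> StartAsync(string programName, CancellationToken token = default(CancellationToken))
{
    // Create a task for each program and fire them "at the same time"
    Task<bool>[] startingChannels = _accounts.GetPrograms(programName)
        .Select(n => StartChannelAsync(n))
        .ToArray();
    // Create a task that will be completed when all the supplied tasks are done
    bool[] results = await Task.WhenAll(startingChannels);
    return results.All(r => r);
}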
Note: I see that you're passing a CancellationToken to your StartAsync method, but you're not actually using it. Consider passing it as an argument to StartChannelAsync, and then use it when calling azureProgram.Channel.StartAsync
If you love one-liners:
public async Task<bool> StartAsync(string programName, CancellationToken token = default(CancellationToken))
{
    return (await Task.WhenAll(_accounts.GetPrograms(programName)
        .Select(p => StartChannelAsync(p))
        .ToArray())).All(r => r);
}
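To illustrate why the await version lost its results, here is a small self-contained sketch. StartChannelAsync here is a hypothetical stand-in that only delays; the thread-based method reproduces the question's pattern with await, and StartAllAsync applies the Task.WhenAll fix.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class ThreadVsWhenAll
{
    // Hypothetical stand-in for StartChannelAsync: a slow async operation.
    static async Task<bool> StartChannelAsync(int id)
    {
        await Task.Delay(200);
        return true;
    }

    // The question's pattern with await: the lambda is async void, so each thread
    // ends at its first await and Join() returns before any result is added.
    public static int CountAfterJoinWithAwaitInThread(int channels)
    {
        var results = new List<bool>();
        var threads = Enumerable.Range(0, channels).Select(id => new Thread(async () =>
        {
            bool result = await StartChannelAsync(id);
            lock (results) { results.Add(result); }
        })).ToList();
        threads.ForEach(t => t.Start());
        threads.ForEach(t => t.Join());
        lock (results) { return results.Count; } // timing-dependent, typically 0
    }

    // The fix: start all operations, then await them together.
    public static async Task<bool> StartAllAsync(int channels)
    {
        bool[] results = await Task.WhenAll(Enumerable.Range(0, channels).Select(StartChannelAsync));
        return results.Length == channels && results.All(r => r);
    }
}
```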

How to - Multiple Async tasks with timeout and cancellation

I would like to fire several tasks while setting a timeout on them. The idea is to gather the results from the tasks that beat the clock, and cancel (or even just ignore) the other tasks.
I tried using the WithCancellation extension method as explained here; however, throwing an exception caused WhenAll to return without supplying any results.
Here's what I tried, but I'm open to other directions as well (note, however, that I need to use await rather than Task.Run, since I need the httpContext in the tasks):
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3));
IEnumerable<Task<MyResults>> tasks =
    from url in urls
    select taskAsync(url).WithCancellation(cts.Token);
Task<MyResults>[] excutedTasks = null;
MyResults[] res = null;
try
{
    // Execute the query and start the searches:
    excutedTasks = tasks.ToArray();
    res = await Task.WhenAll(excutedTasks);
}
catch (Exception exc)
{
    if (excutedTasks != null)
    {
        foreach (Task<MyResults> faulted in excutedTasks.Where(t => t.IsFaulted))
        {
            // work with faulted and faulted.Exception
        }
    }
}
// work with res
EDIT:
Following @Servy's answer below, this is the implementation I went with:
var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3));
IEnumerable<Task<MyResults>> tasks =
    from url in urls
    select taskAsync(url).WithCancellation(cts.Token);
// Execute the query and start the searches:
Task<MyResults>[] excutedTasks = tasks.ToArray();
try
{
    await Task.WhenAll(excutedTasks);
}
catch (OperationCanceledException)
{
    // Do nothing: we expect this if a timeout has occurred
}
IEnumerable<Task<MyResults>> completedTasks = excutedTasks.Where(t => t.Status == TaskStatus.RanToCompletion);
// These tasks have already completed, so reading Result does not block
var results = completedTasks.Select(t => t.Result).ToList();
If any of the tasks fails to complete, you are correct that WhenAll doesn't return the results of the ones that did complete; it just wraps all of the failures in an AggregateException (and await rethrows only the first of them). Fortunately, you still have the original collection of tasks, so you can get the results of the ones that completed successfully from there.
var completedTasks = excutedTasks.Where(t => t.Status == TaskStatus.RanToCompletion);
Just use that instead of res.
I tried your code and it worked fine, except that the canceled tasks are not in the Faulted state but in the Canceled state. So if you want to process the canceled tasks, use t.IsCanceled instead. The non-canceled tasks ran to completion. Here is the code I used:
public static async Task MainAsync()
{
    var urls = new List<string> {"url1", "url2", "url3", "url4", "url5", "url6"};
    var cts = new CancellationTokenSource(TimeSpan.FromSeconds(3));
    IEnumerable<Task<MyResults>> tasks =
        from url in urls
        select taskAsync(url).WithCancellation(cts.Token);
    Task<MyResults>[] excutedTasks = null;
    MyResults[] res = null;
    try
    {
        // Execute the query and start the searches:
        excutedTasks = tasks.ToArray();
        res = await Task.WhenAll(excutedTasks);
    }
    catch (Exception exc)
    {
        if (excutedTasks != null)
        {
            foreach (Task<MyResults> faulted in excutedTasks.Where(t => t.IsFaulted))
            {
                // work with faulted and faulted.Exception
            }
        }
    }
}
public static async Task<MyResults> taskAsync(string url)
{
    Console.WriteLine("Start " + url);
    var random = new Random();
    var delay = random.Next(10);
    await Task.Delay(TimeSpan.FromSeconds(delay));
    Console.WriteLine("End " + url);
    return new MyResults();
}
private static void Main(string[] args)
{
    MainAsync().Wait();
}
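If ignoring the late tasks is acceptable (the question allows "cancel (or even just ignore)"), a variant that avoids exceptions altogether is to race Task.WhenAll against Task.Delay and then keep the tasks that ran to completion. This is a minimal sketch, not the approach from the answers above: ResultsWithinAsync is a hypothetical name, and the late tasks keep running in the background (pass a CancellationToken into the operations themselves if they must actually stop).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class TimeoutHelpers
{
    // Returns the results of the tasks that finished successfully within the timeout.
    public static async Task<List<T>> ResultsWithinAsync<T>(
        IReadOnlyCollection<Task<T>> tasks, TimeSpan timeout)
    {
        // Wait until every task has finished or the timeout elapses, whichever comes first.
        // Task.WhenAny never throws, so faulted or canceled tasks do not end the wait early.
        await Task.WhenAny(Task.WhenAll(tasks), Task.Delay(timeout));
        // Keep only the tasks that completed successfully; reading Result does not block here.
        return tasks.Where(t => t.Status == TaskStatus.RanToCompletion)
                    .Select(t => t.Result)
                    .ToList();
    }
}
```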
