Async recipes: after first (matching) task, cancel rest

When it comes to the Task-based Asynchronous Pattern, there are easy patterns for running tasks sequentially (await each one in a loop) or all in parallel (start them all, collect them in a list, then await WhenAll). Things get slightly more complicated when the execution model doesn’t fit neatly into either category, and even more so when cancellation is involved. Here is an example, derived from a real-life use case I saw recently:

  • Start with N async functions returning values of type T
  • Execute all N functions in parallel
  • Return the first T result matching predicate P
  • Cancel any tasks that are still pending
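For contrast, the two "easy" patterns from the opening paragraph might look like the sketch below. The `BasicPatterns` class and its method names are hypothetical; each function takes a `CancellationToken` only to match the signature used in the rest of this post.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helpers sketching the two "easy" TAP patterns.
public static class BasicPatterns
{
    // Sequential: await each function in a loop, so the next one starts
    // only after the previous one has finished.
    public static async Task<List<T>> RunSequentialAsync<T>(
        IEnumerable<Func<CancellationToken, Task<T>>> funcs, CancellationToken token)
    {
        var results = new List<T>();
        foreach (Func<CancellationToken, Task<T>> func in funcs)
        {
            results.Add(await func(token));
        }

        return results;
    }

    // Parallel: start every function first (collecting the tasks in a list),
    // then await them all at once. WhenAll returns results in input order.
    public static async Task<T[]> RunParallelAsync<T>(
        IEnumerable<Func<CancellationToken, Task<T>>> funcs, CancellationToken token)
    {
        var tasks = new List<Task<T>>();
        foreach (Func<CancellationToken, Task<T>> func in funcs)
        {
            tasks.Add(func(token));
        }

        return await Task.WhenAll(tasks);
    }
}
```

Neither shape covers the requirements above on its own: the loop never overlaps the functions, and WhenAll never stops early once a matching result is available.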

The .NET sample "Cancel Remaining Async Tasks after One Is Complete" almost achieves the basics of what we want, but it has some undesirable qualities. First, it is rather UI-centric, mixing the library logic in with the UI flow. Second, mostly due to that UI entanglement, it uses async void, which you should avoid! Third, it does not actually wait for the canceled tasks to finish. In that sample this is not a big deal, but in many situations returning before awaiting the dangling tasks could cause tricky race conditions (e.g. if the async methods are still accessing shared state but have moved beyond the point where they can be cooperatively canceled).

Instead, let’s attempt to fulfill the above use case with some TDD’d async code. We can start with the simplest meaningful case: running a single task that completes synchronously and whose result matches the predicate (commit link). At this point, we have settled on the basic direction, which is an extension method operating on a sequence of Funcs, like so:

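using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace TaskSample.Extensions
{
    public static class FuncTaskExtensions
    {
        public static Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            // Deliberately naive: ignores the predicate and just runs the first function.
            return funcs.First()(CancellationToken.None);
        }
    }
}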

The implementation so far is bogus, just enough to pass the test.
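The linked test itself is not reproduced here. A framework-free sketch of what such a first test might check, assuming a `string` result and an equality predicate (the `FirstTest` class and its method are hypothetical), is shown below together with a copy of the bogus implementation so it compiles on its own:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TaskSample.Extensions;

namespace TaskSample.Extensions
{
    public static class FuncTaskExtensions
    {
        // Copy of the deliberately bogus first implementation.
        public static Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            return funcs.First()(CancellationToken.None);
        }
    }
}

// Hypothetical stand-in for the linked first test.
public static class FirstTest
{
    public static async Task<bool> SingleSyncMatchAsync()
    {
        var funcs = new List<Func<CancellationToken, Task<string>>>
        {
            // Completes synchronously with a value the predicate accepts.
            t => Task.FromResult("match"),
        };

        string result = await funcs.FirstAsync(s => s == "match");
        return result == "match";
    }
}
```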

As it turns out, this implementation is good enough to pass another test where we provide two functions, the first of which completes synchronously and matches the predicate (commit link). It is still not that interesting, especially since we haven’t even used the predicate yet. The next test forces the issue by putting the matching result in the second of two functions (commit link). At this point, the code looks like this:

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            foreach (Func<CancellationToken, Task<T>> func in funcs)
            {
                T result = await func(CancellationToken.None);
                if (pred(result))
                {
                    return result;
                }
            }

            return default(T);
        }
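Because this version awaits each function before starting the next, a function that never completes blocks everything after it, including a later function that would match. A self-contained sketch of that failure mode follows; `SequentialDemo` is a hypothetical name, its method body is copied from the version above (as a plain static method rather than an extension), and the timeout only bounds how long the check is willing to wait.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

public static class SequentialDemo
{
    // Same logic as the sequential FirstAsync above, as a plain static method.
    public static async Task<T> SequentialFirstAsync<T>(IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
    {
        foreach (Func<CancellationToken, Task<T>> func in funcs)
        {
            T result = await func(CancellationToken.None);
            if (pred(result))
            {
                return result;
            }
        }

        return default(T);
    }

    // Returns true if SequentialFirstAsync finishes within the timeout.
    public static async Task<bool> CompletesWithinAsync(TimeSpan timeout)
    {
        var funcs = new List<Func<CancellationToken, Task<int>>>
        {
            // Never completes: nothing ever sets this TaskCompletionSource.
            t => new TaskCompletionSource<int>().Task,
            // Would match immediately, but the loop never reaches it.
            t => Task.FromResult(42),
        };

        Task<int> first = SequentialFirstAsync(funcs, x => x == 42);
        Task winner = await Task.WhenAny(first, Task.Delay(timeout));
        return winner == first;
    }
}
```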

Up to this point, there was no need to cancel, so let’s push that boundary. We can construct a test case with two functions where the first will never complete and the second will synchronously return a matching value (commit link). The code has gotten significantly more complex:

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            var tasks = new List<Task<T>>();
            T firstResult = default(T);
            using (CancellationTokenSource cts = new CancellationTokenSource())
            {
                foreach (Func<CancellationToken, Task<T>> func in funcs)
                {
                    if (cts.IsCancellationRequested)
                    {
                        break;
                    }

                    Func<CancellationToken, Task<T>> match = async t =>
                    {
                        T result = await func(t);
                        if (pred(result))
                        {
                            firstResult = result;
                            cts.Cancel();
                        }

                        return result;
                    };

                    Task<T> task = match(CancellationToken.None);
                    tasks.Add(task);
                }

                await Task.WhenAny(tasks);
                return firstResult;
            }
        }
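One behavior of Task.WhenAny matters for the next test: the task it returns completes successfully even when the task that finished first is faulted, so awaiting it does not rethrow. In the version above, a throwing function therefore leaves `firstResult` at `default(T)`, and that value is returned as if it were a real result. A minimal standalone check (the `WhenAnyFaultDemo` name is hypothetical):

```csharp
using System;
using System.Threading.Tasks;

public static class WhenAnyFaultDemo
{
    public static async Task<bool> WhenAnyHidesFaultAsync()
    {
        // Already faulted, like a function that threw.
        Task<int> faulted = Task.FromException<int>(new InvalidOperationException("boom"));
        // Never completes.
        Task<int> pending = new TaskCompletionSource<int>().Task;

        // Awaiting WhenAny does not throw; its result is simply the faulted task.
        Task<int> first = await Task.WhenAny(faulted, pending);
        return first.IsFaulted && first.Exception.InnerException is InvalidOperationException;
    }
}
```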

It is questionable whether all of this code is actually needed at this point, but it does pass the test. The implementation is stable enough to pass another test where we have a single function that returns asynchronously (commit link). However, it cannot handle a case where a function throws instead of returning (commit link): the exception stays inside the faulted task, WhenAny completes normally, and the method returns default(T) as if it were a real result. This revised version does the trick:

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            var tasks = new List<Task<T>>();
            Tuple<T> firstResult = null;
            using (CancellationTokenSource cts = new CancellationTokenSource())
            {
                foreach (Func<CancellationToken, Task<T>> func in funcs)
                {
                    if (cts.IsCancellationRequested)
                    {
                        break;
                    }

                    Func<CancellationToken, Task<T>> match = async t =>
                    {
                        T result = await func(t);
                        if (pred(result))
                        {
                            firstResult = Tuple.Create(result);
                            cts.Cancel();
                        }

                        return result;
                    };

                    Task<T> task = match(CancellationToken.None);
                    tasks.Add(task);
                }

                await Task.WhenAny(tasks);
                if (firstResult == null)
                {
                    throw new InvalidOperationException("No matching result.");
                }

                return firstResult.Item1;
            }
        }
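The `firstResult == null` check above works because the result is wrapped in a single-element tuple, which separates "nothing recorded" from "a null value was recorded". A tiny standalone illustration of the three states (the `TupleTrickDemo.Describe` helper is hypothetical):

```csharp
using System;

public static class TupleTrickDemo
{
    // A null tuple means "no result recorded"; a non-null tuple means
    // "a result was recorded", even when that result is itself null.
    public static string Describe(Tuple<string> slot)
    {
        if (slot == null)
        {
            return "no result";
        }

        return slot.Item1 == null ? "result was null" : "result: " + slot.Item1;
    }
}
```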

The main difference is that we now keep track of whether a matching result was produced at all, and throw if it was not. A single-element tuple is a handy trick here: if the tuple itself is null, no matching result was recorded, as opposed to Item1 being null, which means the function explicitly returned null (and the predicate accepted it). Again, this implementation is good enough to pass a similar case where we throw asynchronously instead of synchronously (commit link). It also passes a case where neither function returns a matching result (commit link). After some minor test refactoring (commit link), we can move to the next case, which exposes a relatively major logic error. This time we’re trying three async functions where only the last one matches the predicate (commit link). The corrected implementation now looks like this:

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            var tasks = new List<Task<T>>();
            Tuple<T> firstResult = null;
            using (CancellationTokenSource cts = new CancellationTokenSource())
            {
                foreach (Func<CancellationToken, Task<T>> func in funcs)
                {
                    if (cts.IsCancellationRequested)
                    {
                        break;
                    }

                    Func<CancellationToken, Task<T>> match = async t =>
                    {
                        T result = default(T);
                        try
                        {
                            result = await func(t);
                            if (pred(result))
                            {
                                firstResult = Tuple.Create(result);
                                cts.Cancel();
                            }
                        }
                        catch (Exception)
                        {
                        }

                        return result;
                    };

                    Task<T> task = match(cts.Token);
                    tasks.Add(task);
                }

                await Task.WhenAll(tasks);
                if (firstResult == null)
                {
                    throw new InvalidOperationException("No matching result.");
                }

                return firstResult.Item1;
            }
        }
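The behavioral difference behind this fix can be shown deterministically with two TaskCompletionSource instances standing in for a non-matching task that finishes early and a matching task that finishes later. The `WhenAnyVsWhenAllDemo` class is hypothetical:

```csharp
using System;
using System.Threading.Tasks;

public static class WhenAnyVsWhenAllDemo
{
    public static async Task<(int firstValue, bool allDoneEarly, int[] allValues)> RunAsync()
    {
        var nonMatching = new TaskCompletionSource<int>();
        var matching = new TaskCompletionSource<int>();

        Task<Task<int>> any = Task.WhenAny(nonMatching.Task, matching.Task);
        Task<int[]> all = Task.WhenAll(nonMatching.Task, matching.Task);

        // The non-matching task finishes first.
        nonMatching.SetResult(1);

        // WhenAny is satisfied by that single completion...
        Task<int> winner = await any;
        // ...while WhenAll is still waiting on the other task.
        bool allDoneEarly = all.IsCompleted;

        matching.SetResult(3);
        int[] allValues = await all;
        return (winner.Result, allDoneEarly, allValues);
    }
}
```

Awaiting WhenAll also rethrows if any of its tasks faulted (or throws TaskCanceledException if none faulted but some were canceled), which is why the version above catches exceptions inside each `match` call.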

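The mistake was subtle: WhenAny completed as soon as any task finished, matching or not, so the operation returned too early and threw "No matching result." even though a later task would have matched. Switching to WhenAll fixes that, but awaiting WhenAll rethrows exceptions from faulted or canceled tasks, so each call is now wrapped in a try/catch that swallows them. Each function also now receives cts.Token instead of CancellationToken.None, so pending functions can actually observe the cancellation. Next, we refactor the implementation, extracting some of the logic into its own class (commit link):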

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            var tasks = new List<Task<T>>();
            using (MatchFunc<T> match = new MatchFunc<T>(pred))
            {
                foreach (Func<CancellationToken, Task<T>> func in funcs)
                {
                    if (match.Canceled)
                    {
                        break;
                    }

                    Task<T> task = match.RunAsync(func);
                    tasks.Add(task);
                }

                await Task.WhenAll(tasks);
                return match.Result;
            }
        }

        private sealed class MatchFunc<T> : IDisposable
        {
            private readonly CancellationTokenSource cts;
            private readonly Predicate<T> pred;

            private volatile Tuple<T> firstResult;

            public MatchFunc(Predicate<T> pred)
            {
                this.cts = new CancellationTokenSource();
                this.pred = pred;
            }

            public bool Canceled => this.cts.IsCancellationRequested;

            public T Result
            {
                get
                {
                    if (this.firstResult == null)
                    {
                        throw new InvalidOperationException("No matching result.");
                    }

                    return this.firstResult.Item1;
                }
            }

            public async Task<T> RunAsync(Func<CancellationToken, Task<T>> func)
            {
                T result = default(T);
                try
                {
                    result = await func(this.cts.Token);
                    if (this.pred(result))
                    {
                        this.firstResult = Tuple.Create(result);
                        this.cts.Cancel();
                    }
                }
                catch (Exception)
                {
                }

                return result;
            }

            public void Dispose()
            {
                this.cts.Dispose();
            }
        }
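Cancellation tests need functions that hang until their token is signaled. The linked tests are not reproduced here; one way to write such a function, assuming Task.Delay with Timeout.Infinite is an acceptable way to "hang" (the names `HangAsync` and `CancelEndsHangAsync` are hypothetical), is:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class HangUntilCanceledDemo
{
    // Hypothetical test helper: never produces a value on its own and only
    // finishes (in the Canceled state) once its token is signaled.
    public static async Task<int> HangAsync(CancellationToken token)
    {
        // Timeout.Infinite (-1) waits forever; cancellation ends the wait by
        // throwing TaskCanceledException (an OperationCanceledException).
        await Task.Delay(Timeout.Infinite, token);
        return 0; // not reached unless the delay somehow completes
    }

    public static async Task<bool> CancelEndsHangAsync()
    {
        using (var cts = new CancellationTokenSource())
        {
            Task<int> hanging = HangAsync(cts.Token);
            bool pendingBeforeCancel = !hanging.IsCompleted;

            cts.Cancel();
            try
            {
                await hanging;
                return false;
            }
            catch (OperationCanceledException)
            {
                return pendingBeforeCancel && hanging.IsCanceled;
            }
        }
    }
}
```

When MatchFunc cancels its token, a function like this ends up canceled, and the catch block in RunAsync swallows the resulting exception.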

The code got longer, but I would argue it is easier to see what is happening. After that, I added another test covering a five-function scenario in which every function except the last hangs until explicitly canceled, to make sure the implementation behaves correctly under cancellation (commit link). Then came some minor cleanup to remove the unused result from the inner task (commit link), and a finishing touch to ensure that we assign the result and cancel exactly once (commit link), using our old friend Interlocked.CompareExchange:

        public static async Task<T> FirstAsync<T>(this IEnumerable<Func<CancellationToken, Task<T>>> funcs, Predicate<T> pred)
        {
            var tasks = new List<Task>();
            using (MatchFunc<T> match = new MatchFunc<T>(pred))
            {
                foreach (Func<CancellationToken, Task<T>> func in funcs)
                {
                    if (match.Canceled)
                    {
                        break;
                    }

                    Task task = match.RunAsync(func);
                    tasks.Add(task);
                }

                await Task.WhenAll(tasks);
                return match.Result;
            }
        }

        private sealed class MatchFunc<T> : IDisposable
        {
            private readonly CancellationTokenSource cts;
            private readonly Predicate<T> pred;

            private Tuple<T> firstResult;

            public MatchFunc(Predicate<T> pred)
            {
                this.cts = new CancellationTokenSource();
                this.pred = pred;
            }

            public bool Canceled => this.cts.IsCancellationRequested;

            public T Result
            {
                get
                {
                    if (this.firstResult == null)
                    {
                        throw new InvalidOperationException("No matching result.");
                    }

                    return this.firstResult.Item1;
                }
            }

            public async Task RunAsync(Func<CancellationToken, Task<T>> func)
            {
                try
                {
                    T result = await func(this.cts.Token);
                    if (this.pred(result))
                    {
                        this.Complete(result);
                    }
                }
                catch (Exception)
                {
                }
            }

            public void Dispose()
            {
                this.cts.Dispose();
            }

            private void Complete(T result)
            {
                if (Interlocked.CompareExchange(ref this.firstResult, Tuple.Create(result), null) == null)
                {
                    this.cts.Cancel();
                }
            }
        }
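The CompareExchange call in Complete is what guarantees a single winner: it swaps in the new tuple only if the field is still null and returns the previous value, so exactly one caller sees null come back. A stripped-down sketch of that "first writer wins" step, with a counter standing in for cts.Cancel() so the outcome can be checked (`FirstWriterWins` and `FirstWriterWinsDemo` are hypothetical names):

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical stand-in for MatchFunc.Complete, reduced to the race itself.
public sealed class FirstWriterWins<T>
{
    private Tuple<T> firstResult;
    private int winners;

    public Tuple<T> Result => Volatile.Read(ref this.firstResult);

    public int Winners => Volatile.Read(ref this.winners);

    public void Complete(T result)
    {
        // Only the call that observes null swaps its tuple in; every later
        // call sees a non-null value and leaves the field untouched.
        if (Interlocked.CompareExchange(ref this.firstResult, Tuple.Create(result), null) == null)
        {
            Interlocked.Increment(ref this.winners); // MatchFunc calls cts.Cancel() here
        }
    }
}

public static class FirstWriterWinsDemo
{
    public static FirstWriterWins<int> Race(int writers)
    {
        var slot = new FirstWriterWins<int>();
        // Many concurrent writers try to record "their" result.
        Parallel.For(0, writers, i => slot.Complete(i));
        return slot;
    }
}
```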

This function seems to work and solves the original use case in a repeatable way. Not bad for an hour’s work!
