Laziness is a virtue

Sometimes you want lazy initialization, but your initialization function is asynchronous. Stephen Toub wrote about this conundrum years ago on the pfxteam blog. The solution he describes is fairly straightforward: combine Lazy<T> with Task<T> and you’re 90% there. Stephen Cleary later packaged this up as AsyncLazy in his popular Nito.AsyncEx library.

One quibble I have with these implementations is that they force a thread switch: the initialization function is wrapped in a Task.Factory.StartNew call, which hands it off to a scheduler (normally the thread pool) instead of running it on the caller’s thread. If your asynchronous routine is well behaved and guaranteed not to block (and it should be!), you don’t need this extra layer. So we ostensibly know the solution, but let’s apply the rigor of TDD to show that it meets our specifications.
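To make the difference concrete, here is a small sketch; the `ThreadSwitchDemo` class and its method names are hypothetical. It checks that a plain `Lazy<Task<T>>` runs the async factory on the calling thread, and, for contrast, shows the general shape of a `Task.Factory.StartNew`-wrapped factory. The wrapped version is an illustration of the pattern, not either author’s exact code.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo contrasting two ways of building a Lazy<Task<T>>.
public static class ThreadSwitchDemo
{
    // Direct: Lazy<T> invokes the async factory synchronously,
    // on whichever thread first reads Value.
    public static bool DirectFactoryRunsOnCallerThread()
    {
        int callerThread = Environment.CurrentManagedThreadId;
        int factoryThread = -1;

        var lazy = new Lazy<Task<int>>(
            () =>
            {
                factoryThread = Environment.CurrentManagedThreadId;
                return Task.FromResult(42);
            },
            LazyThreadSafetyMode.ExecutionAndPublication);

        Task<int> task = lazy.Value;
        return task.IsCompleted && task.Result == 42 && factoryThread == callerThread;
    }

    // Wrapped (general shape of the StartNew-based approach): the factory is queued
    // to TaskScheduler.Current (normally the thread pool), and the resulting
    // Task<Task<int>> is flattened with Unwrap().
    public static Lazy<Task<int>> CreateWrapped(Func<Task<int>> initializeAsync)
    {
        return new Lazy<Task<int>>(
            () => Task.Factory.StartNew(initializeAsync).Unwrap(),
            LazyThreadSafetyMode.ExecutionAndPublication);
    }
}
```

`DirectFactoryRunsOnCallerThread()` should return `true`. Which thread runs the wrapped factory depends on `TaskScheduler.Current`, so the sketch deliberately makes no claim about it beyond the extra scheduling hop.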

Here is the starting code:

using System;
using System.Threading;
using System.Threading.Tasks;

public sealed class AsyncLazy<T>
{
    private readonly Lazy<Task<T>> lazy;

    public AsyncLazy(Func<Task<T>> initializeAsync)
    {
        this.lazy = new Lazy<Task<T>>(initializeAsync, LazyThreadSafetyMode.ExecutionAndPublication);
    }

    public Task<T> GetAsync()
    {
        return this.lazy.Value;
    }
}
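The tests that follow are single-threaded; protection against concurrent first calls comes from the LazyThreadSafetyMode.ExecutionAndPublication argument. A minimal sketch of that guarantee follows. The `ConcurrentAccessDemo` name is hypothetical, and it builds a `Lazy<Task<int>>` directly, on the assumption that this is equivalent because GetAsync only returns `lazy.Value`.

```csharp
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical demo: many threads race to read the same Lazy<Task<int>>.
public static class ConcurrentAccessDemo
{
    public static (int factoryCalls, bool allSameTask) Run(int callers)
    {
        int factoryCalls = 0;
        var gate = new TaskCompletionSource<int>();

        // Same construction as AsyncLazy<T>: the factory returns a task, and
        // ExecutionAndPublication guarantees it is invoked at most once.
        var lazy = new Lazy<Task<int>>(
            () =>
            {
                Interlocked.Increment(ref factoryCalls);
                return gate.Task; // still pending while the callers race
            },
            LazyThreadSafetyMode.ExecutionAndPublication);

        Task<int>[] seen = new Task<int>[callers];
        Parallel.For(0, callers, i => seen[i] = lazy.Value);

        // Complete initialization only after every caller has its task.
        gate.SetResult(99);
        Task.WaitAll(seen);

        return (factoryCalls, seen.All(t => ReferenceEquals(t, seen[0])));
    }
}
```

However the threads interleave, `Run(16)` should report exactly one factory call and a single task instance shared by every caller.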

The first few tests show that the initialization and cached value semantics are correct:

[TestMethod]
public void ShouldInitializeOnFirstCallCompleteSync()
{
    int count = 0;
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => Task.FromResult(++count));

    Task<int> task = lazy.GetAsync();

    Assert.IsTrue(task.IsCompleted);
    Assert.AreEqual(1, task.Result);
}

[TestMethod]
public void ShouldReturnCachedValueAfterInitialization()
{
    int count = 0;
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => Task.FromResult(++count));

    lazy.GetAsync();
    Task<int> task = lazy.GetAsync();

    Assert.IsTrue(task.IsCompleted);
    Assert.AreEqual(1, task.Result);
}

[TestMethod]
public void ShouldLeaveTaskPendingUntilInitializationCompletes()
{
    TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => tcs.Task);

    Task<int> task = lazy.GetAsync();

    Assert.IsFalse(task.IsCompleted);

    tcs.SetResult(123);

    Assert.IsTrue(task.IsCompleted);
    Assert.AreEqual(123, task.Result);
}

[TestMethod]
public void ShouldCompleteAllPendingTasksInOrderAfterInitializationCompletes()
{
    List<string> steps = new List<string>();
    TaskCompletionSource<int> tcs = new TaskCompletionSource<int>();
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => tcs.Task);

    lazy.GetAsync().ContinueWith(t => steps.Add("first:" + t.Result), TaskContinuationOptions.ExecuteSynchronously);
    lazy.GetAsync().ContinueWith(t => steps.Add("second:" + t.Result), TaskContinuationOptions.ExecuteSynchronously);

    tcs.SetResult(321);

    CollectionAssert.AreEqual(new string[] { "first:321", "second:321" }, steps);
}

Now we move on to the exceptional cases. Let’s review the behavior of the core Lazy<T> class according to the MSDN documentation:

…if the factory method throws an exception the first time a thread tries to access the Value property of the Lazy<T> object, the same exception is thrown on every subsequent attempt.
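That behavior can be checked in isolation with the sketch below (the `LazyExceptionCachingDemo` name is hypothetical). Under ExecutionAndPublication, the factory should run once, and both reads of Value should throw the very same exception object.

```csharp
using System;
using System.Threading;

// Hypothetical demo of Lazy<T> caching a factory exception.
public static class LazyExceptionCachingDemo
{
    public static (int factoryCalls, bool sameException) Run()
    {
        int factoryCalls = 0;
        var lazy = new Lazy<int>(
            () =>
            {
                factoryCalls++;
                throw new InvalidOperationException("boom");
            },
            LazyThreadSafetyMode.ExecutionAndPublication);

        // Read Value twice; the second read should not re-run the factory.
        Exception first = Capture(() => { var ignored = lazy.Value; });
        Exception second = Capture(() => { var ignored = lazy.Value; });

        return (factoryCalls, first != null && ReferenceEquals(first, second));
    }

    // Runs the action and returns whatever it threw, or null if it succeeded.
    private static Exception Capture(Action action)
    {
        try { action(); return null; }
        catch (Exception e) { return e; }
    }
}
```

This caching is mode-specific: under LazyThreadSafetyMode.PublicationOnly, Lazy<T> does not cache exceptions, so a failed factory would be retried on the next access.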

The following tests (shown post-refactoring) should prove the same about our async implementation. Note that they cover both synchronous failures (thrown immediately from the factory method) and asynchronous ones (surfaced through a faulted Task<T>):

[TestMethod]
public void ShouldThrowExceptionOnGetIfInitializationFailsSync()
{
    ShouldThrowExceptionOnGetIfInitializationFails(true);
}

[TestMethod]
public void ShouldThrowCachedExceptionOnSecondGetIfInitializationFailsSync()
{
    ShouldThrowCachedExceptionOnSecondGetIfInitializationFails(true);
}

[TestMethod]
public void ShouldThrowExceptionOnGetIfInitializationFailsAsync()
{
    ShouldThrowExceptionOnGetIfInitializationFails(false);
}

[TestMethod]
public void ShouldThrowCachedExceptionOnSecondGetIfInitializationFailsAsync()
{
    ShouldThrowCachedExceptionOnSecondGetIfInitializationFails(false);
}

private static void ShouldThrowExceptionOnGetIfInitializationFails(bool isSync)
{
    Exception expected = new InvalidOperationException("Expected exception.");
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => Throw(0, expected, isSync));

    Task<int> task = lazy.GetAsync();

    AssertTaskCompletedWithException(expected, task);
}

private static void ShouldThrowCachedExceptionOnSecondGetIfInitializationFails(bool isSync)
{
    int count = 0;
    Exception expected = new InvalidOperationException("Expected exception.");
    AsyncLazy<int> lazy = new AsyncLazy<int>(() => Throw(++count, expected, isSync));

    lazy.GetAsync();
    Task<int> task = lazy.GetAsync();

    Assert.AreEqual(1, count);
    AssertTaskCompletedWithException(expected, task);
}

private static void AssertTaskCompletedWithException<TException>(TException expected, Task task) where TException : Exception
{
    Assert.IsTrue(task.IsCompleted);
    Assert.IsTrue(task.IsFaulted);
    Assert.IsNotNull(task.Exception);
    Assert.AreEqual(1, task.Exception.InnerExceptions.Count);
    Assert.AreSame(expected, task.Exception.InnerException);
}

private static Task<T> Throw<T, TException>(T value, TException exception, bool isSync) where TException : Exception
{
    if (isSync)
    {
        throw exception;
    }
    else
    {
        TaskCompletionSource<T> tcs = new TaskCompletionSource<T>();
        tcs.SetException(exception);
        return tcs.Task;
    }
}

But this is where we run into trouble: the synchronous tests fail because, in our current implementation, the exception does not become a faulted task. Instead, it is thrown directly from GetAsync (and, because Lazy<T> caches it, rethrown on every later call). I consider this undesirable, so I fixed the implementation to follow the test specification:

using System;
using System.Threading;
using System.Threading.Tasks;

public sealed class AsyncLazy<T>
{
    private readonly Lazy<Task<T>> lazy;

    public AsyncLazy(Func<Task<T>> initializeAsync)
    {
        this.lazy = new Lazy<Task<T>>(WrapException(initializeAsync), LazyThreadSafetyMode.ExecutionAndPublication);
    }

    public Task<T> GetAsync()
    {
        return this.lazy.Value;
    }

    private static Func<Task<T>> WrapException(Func<Task<T>> initializeAsync)
    {
        return delegate
        {
            try
            {
                return initializeAsync();
            }
            catch (Exception e)
            {
                TaskCompletionSource<T> tcs = new TaskCompletionSource<T>();
                tcs.SetException(e);
                return tcs.Task;
            }
        };
    }
}

The key to the fix is the WrapException method. It catches any exception the factory throws synchronously and returns it as a faulted Task<T>, so callers observe both kinds of failure the same way: through the returned task.
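On .NET Framework 4.6 or later (and .NET Core), the TaskCompletionSource<T> lines in the catch block can be replaced with Task.FromException<T>. Here is a sketch of that variant, written as a standalone generic helper so it can be checked on its own; the `AsyncInit` class and `SyncAndAsyncFailuresMatch` method are hypothetical names.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper mirroring the post's WrapException, but building the
// faulted task with Task.FromException<T> instead of a TaskCompletionSource<T>.
public static class AsyncInit
{
    public static Func<Task<T>> WrapException<T>(Func<Task<T>> initializeAsync)
    {
        return () =>
        {
            try
            {
                // Normal path: hand back the factory's own task untouched.
                return initializeAsync();
            }
            catch (Exception e)
            {
                // A synchronous throw becomes a faulted task, matching the async failure path.
                return Task.FromException<T>(e);
            }
        };
    }

    // Checks that a synchronous throw and an already-faulted task look identical
    // to a caller reading Lazy<Task<int>>.Value once the factory is wrapped.
    public static bool SyncAndAsyncFailuresMatch(Exception expected)
    {
        var syncLazy = new Lazy<Task<int>>(
            WrapException<int>(() => throw expected),
            LazyThreadSafetyMode.ExecutionAndPublication);
        var asyncLazy = new Lazy<Task<int>>(
            WrapException(() => Task.FromException<int>(expected)),
            LazyThreadSafetyMode.ExecutionAndPublication);

        return IsFaultedWith(syncLazy.Value, expected)
            && IsFaultedWith(asyncLazy.Value, expected);
    }

    private static bool IsFaultedWith(Task task, Exception expected) =>
        task.IsFaulted
        && task.Exception.InnerExceptions.Count == 1
        && ReferenceEquals(task.Exception.InnerException, expected);
}
```

An `async () => await initializeAsync()` wrapper would also convert synchronous throws, but it introduces a second task and a continuation, which runs against the goal of avoiding extra layers; the explicit try/catch returns the factory’s task as-is whenever it doesn’t throw.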

And there you have it: a slightly new twist on an old pattern.
