Efficient concurrency prevention

Sometimes you want asynchrony but not concurrency. For example, if you are writing data to a file, you should generally prefer asynchronous I/O, but you probably don’t want 10 other competing callers to corrupt the contents.

Perhaps the simplest approach is to guard a “writing” flag with a lock, so that a second concurrent caller fails fast instead of interleaving its writes, e.g.:

using System;
using System.Threading.Tasks;

public class MyFile
{
    private readonly object syncRoot;
    private bool writing;
    // ...

    public MyFile()
    {
        this.syncRoot = new object();
    }

    public async Task WriteAsync(byte[] data)
    {
        lock (this.syncRoot)
        {
            if (writing)
            {
                throw new InvalidOperationException("Conflict!");
            }

            writing = true;
        }

        try
        {
            // ... do the actual (awaited) write here...
        }
        finally
        {
            lock (this.syncRoot)
            {
                writing = false;
            }
        }
    }
}

This works and is simple enough, but the lock might be overkill. Locks can be expensive, especially when, as above, every caller is forced to acquire one twice per call (once to set the flag and once to clear it), even in the normal case where there is no conflict. Ideally, you don’t want to make the correct use case significantly slower just to account for possible incorrect usages. (But of course, don’t just take my word for it: measure it and see!)
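If you want to take that advice literally, a rough single-threaded micro-benchmark such as the sketch below compares only the guard bookkeeping: the two lock acquisitions that MyFile performs per call versus one Interlocked.CompareExchange to enter and one Interlocked.Exchange to leave. The names GuardBenchmark, LockGuard, InterlockedGuard and Measure are hypothetical; the sketch covers only the uncontended case, and a dedicated harness such as BenchmarkDotNet will give more trustworthy numbers.

```csharp
using System;
using System.Diagnostics;
using System.Threading;

// Hypothetical micro-benchmark of the guard overhead only (no real I/O).
// It measures the uncontended, single-threaded case.
public static class GuardBenchmark
{
    private static readonly object SyncRoot = new object();
    private static bool writing;
    private static int busy; // 0 = idle, 1 = an operation is in progress

    // Mirrors MyFile.WriteAsync: one lock to set the flag and one to clear it.
    public static void LockGuard()
    {
        lock (SyncRoot)
        {
            if (writing)
            {
                throw new InvalidOperationException("Conflict!");
            }

            writing = true;
        }

        lock (SyncRoot)
        {
            writing = false;
        }
    }

    // Same check using one compare-and-swap to enter and one exchange to leave.
    public static void InterlockedGuard()
    {
        if (Interlocked.CompareExchange(ref busy, 1, 0) != 0)
        {
            throw new InvalidOperationException("Conflict!");
        }

        Interlocked.Exchange(ref busy, 0);
    }

    public static TimeSpan Measure(Action action, int iterations)
    {
        action(); // run once first so JIT compilation is not included in the timing

        Stopwatch stopwatch = Stopwatch.StartNew();
        for (int i = 0; i < iterations; ++i)
        {
            action();
        }

        stopwatch.Stop();
        return stopwatch.Elapsed;
    }
}
```

For example, compare GuardBenchmark.Measure(GuardBenchmark.LockGuard, 10000000) with GuardBenchmark.Measure(GuardBenchmark.InterlockedGuard, 10000000) in a Release build. The results depend on hardware and runtime, so no numbers are claimed here.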

Let me introduce a little class that implements this kind of concurrency prevention without locks. It assumes that the operation’s lifetime is represented by a TaskCompletionSource, which makes it a good fit when the bulk of the work happens elsewhere; for example, it could be used with the VirtualDisk code in my previous post, where the actual work is done by the Windows kernel.

using System;
using System.Threading;
using System.Threading.Tasks;

public class AsyncOperation
{
    private TaskCompletionSource<bool> tcs;

    public Task Start()
    {
        TaskCompletionSource<bool> current = new TaskCompletionSource<bool>();
        if (Interlocked.CompareExchange(ref this.tcs, current, null) != null)
        {
            throw new InvalidOperationException("An operation is already in progress.");
        }

        return current.Task;
    }

    public void Complete()
    {
        TaskCompletionSource<bool> current = Interlocked.Exchange(ref this.tcs, null);
        if (current == null)
        {
            // Programming error: Complete() called without a matching Start().
            throw new InvalidOperationException("There is no active operation.");
        }

        current.SetResult(false);
    }
}

That’s it: just a couple of interlocked operations and we’re good to go. (The “complete with exception” path has been elided for brevity, but it is not too difficult to imagine how it would look.)
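For illustration, the elided failure path might look something like the sketch below. Fail(Exception) and the private TakeCurrent helper are hypothetical names, not from the original code. The idea is that Complete and Fail share the same Interlocked.Exchange, so exactly one of them can finish a given operation, and Fail uses TaskCompletionSource.SetException to fault the task.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Sketch: AsyncOperation as above, plus a hypothetical Fail(Exception) method.
public class AsyncOperation
{
    private TaskCompletionSource<bool> tcs;

    public Task Start()
    {
        TaskCompletionSource<bool> current = new TaskCompletionSource<bool>();
        if (Interlocked.CompareExchange(ref this.tcs, current, null) != null)
        {
            throw new InvalidOperationException("An operation is already in progress.");
        }

        return current.Task;
    }

    public void Complete()
    {
        this.TakeCurrent().SetResult(false);
    }

    // Hypothetical: finishes the running operation by faulting its task.
    public void Fail(Exception exception)
    {
        // Validate before detaching, so a bad argument does not end the operation.
        if (exception == null)
        {
            throw new ArgumentNullException(nameof(exception));
        }

        this.TakeCurrent().SetException(exception);
    }

    // Atomically detaches the in-flight TaskCompletionSource so that exactly one
    // call to Complete or Fail can finish it; any other caller gets an exception.
    private TaskCompletionSource<bool> TakeCurrent()
    {
        TaskCompletionSource<bool> current = Interlocked.Exchange(ref this.tcs, null);
        if (current == null)
        {
            // Programming error: no matching Start().
            throw new InvalidOperationException("There is no active operation.");
        }

        return current;
    }
}
```

A caller awaiting the task returned by Start() would then observe the exception passed to Fail, and the operation can be started again afterwards.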

But does it work? Here are some unit tests covering the basic state transitions:

using System;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class AsyncOperationTest
{
    [TestMethod]
    public void ShouldReturnTaskRepresentingOperationOnStart()
    {
        AsyncOperation operation = new AsyncOperation();

        Task task = operation.Start();

        Assert.AreEqual(TaskStatus.WaitingForActivation, task.Status);
    }

    [TestMethod]
    public void ShouldCompleteStartedTaskOnComplete()
    {
        AsyncOperation operation = new AsyncOperation();

        Task task = operation.Start();
        operation.Complete();

        Assert.AreEqual(TaskStatus.RanToCompletion, task.Status);
    }

    [TestMethod]
    public void ShouldFailToCompleteIfNotStarted()
    {
        AsyncOperation operation = new AsyncOperation();

        try
        {
            operation.Complete();
            Assert.Fail("Expected exception not thrown.");
        }
        catch (InvalidOperationException)
        {
        }
    }

    [TestMethod]
    public void ShouldFailToCompleteIfAlreadyCompleted()
    {
        AsyncOperation operation = new AsyncOperation();

        operation.Start();
        operation.Complete();

        try
        {
            operation.Complete();
            Assert.Fail("Expected exception not thrown.");
        }
        catch (InvalidOperationException)
        {
        }
    }

    [TestMethod]
    public void ShouldFailToStartIfAlreadyStartedWithoutInterferingWithInitialTask()
    {
        AsyncOperation operation = new AsyncOperation();

        Task task = operation.Start();

        try
        {
            operation.Start();
            Assert.Fail("Expected exception not thrown.");
        }
        catch (InvalidOperationException)
        {
        }

        Assert.AreEqual(TaskStatus.WaitingForActivation, task.Status);

        operation.Complete();

        Assert.AreEqual(TaskStatus.RanToCompletion, task.Status);
    }
}

But we’re dealing with concurrency and race conditions, which are hard to demonstrate in an ordinary unit test. So let’s also write a couple of “slow” tests that run many operations deliberately designed to conflict with each other. The premise is that, no matter how many callers are trying, only one should ever be allowed to start a given operation and only one should be able to complete a running operation. A test of that claim might look like the following. It uses a Barrier to divide the work into phases (each thread makes one attempt per phase, and the barrier’s post-phase action resets the state) and a CountdownEvent to detect consistency violations: the event starts at 2, and CountdownEvent.Signal() returns true only when a signal brings the count to zero, which can happen only if a second caller succeeds within the same phase:

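using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]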
public class AsyncOperationSlowTest
{
    [TestMethod]
    public void ShouldOnlyAllowOneTaskInFlightWithManyTryingToStart()
    {
        Task[] tasks = new Task[16];
        AsyncOperation operation = new AsyncOperation();
        CountdownEvent counter = new CountdownEvent(2);

        Action afterWork = delegate
        {
            counter.Reset();
            operation.Complete();
        };

        Barrier barrier = new Barrier(tasks.Length, b => afterWork());

        Action<CancellationToken> doWork = delegate (CancellationToken token)
        {
            int iterations = 0;
            int exceptions = 0;
            try
            {
                while (!token.IsCancellationRequested)
                {
                    ++iterations;
                    bool started = false;
                    try
                    {
                        operation.Start();
                        started = true;
                    }
                    catch (InvalidOperationException)
                    {
                        ++exceptions;
                    }

                    if (started)
                    {
                        bool consistencyViolated = counter.Signal();
                        if (consistencyViolated)
                        {
                            throw new InvalidDataException("Consistency violated! Multiple calls to Start were allowed.");
                        }
                    }

                    barrier.SignalAndWait(token);
                }
            }
            catch (OperationCanceledException)
            {
            }

            Console.WriteLine("Thread {0} ran {1} iterations and caught {2} exceptions.", Thread.CurrentThread.ManagedThreadId, iterations, exceptions);
        };

        using (CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(30.0d)))
        {
            for (int i = 0; i < tasks.Length; ++i)
            {
                tasks[i] = Task.Factory.StartNew(() => doWork(cts.Token), TaskCreationOptions.LongRunning);
            }

            try
            {
                Task.WaitAny(tasks, cts.Token);
            }
            catch (OperationCanceledException)
            {
            }

            Task.WaitAll(tasks);
        }
    }

    [TestMethod]
    public void ShouldOnlyAllowOneTaskInFlightWithManyTryingToComplete()
    {
        Task[] tasks = new Task[16];
        AsyncOperation operation = new AsyncOperation();
        CountdownEvent counter = new CountdownEvent(2);

        Action afterWork = delegate
        {
            counter.Reset();
            operation.Start();
        };

        afterWork();

        Barrier barrier = new Barrier(tasks.Length, b => afterWork());

        Action<CancellationToken> doWork = delegate (CancellationToken token)
        {
            int iterations = 0;
            int exceptions = 0;
            try
            {
                while (!token.IsCancellationRequested)
                {
                    ++iterations;
                    bool completed = false;
                    try
                    {
                        operation.Complete();
                        completed = true;
                    }
                    catch (InvalidOperationException)
                    {
                        ++exceptions;
                    }

                    if (completed)
                    {
                        bool consistencyViolated = counter.Signal();
                        if (consistencyViolated)
                        {
                            throw new InvalidDataException("Consistency violated! Multiple calls to Complete were allowed.");
                        }
                    }

                    barrier.SignalAndWait(token);
                }
            }
            catch (OperationCanceledException)
            {
            }

            Console.WriteLine("Thread {0} ran {1} iterations and caught {2} exceptions.", Thread.CurrentThread.ManagedThreadId, iterations, exceptions);
        };

        using (CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(30.0d)))
        {
            for (int i = 0; i < tasks.Length; ++i)
            {
                tasks[i] = Task.Factory.StartNew(() => doWork(cts.Token), TaskCreationOptions.LongRunning);
            }

            try
            {
                Task.WaitAny(tasks, cts.Token);
            }
            catch (OperationCanceledException)
            {
            }

            Task.WaitAll(tasks);
        }
    }
}

(I leave it as an exercise to the reader to eliminate the duplication…)
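One possible way to remove the duplication is a shared harness that takes the contended call and the between-phases reset as delegates. RaceHarness and its parameters are hypothetical. The sketch also disposes the CountdownEvent, Barrier and CancellationTokenSource, and it drops the Task.WaitAny call, since Task.WaitAll alone already waits for every worker and surfaces a violation as an AggregateException. AsyncOperation is repeated so the block compiles on its own.

```csharp
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

// AsyncOperation as above, repeated so this sketch compiles on its own.
public class AsyncOperation
{
    private TaskCompletionSource<bool> tcs;

    public Task Start()
    {
        TaskCompletionSource<bool> current = new TaskCompletionSource<bool>();
        if (Interlocked.CompareExchange(ref this.tcs, current, null) != null)
        {
            throw new InvalidOperationException("An operation is already in progress.");
        }

        return current.Task;
    }

    public void Complete()
    {
        TaskCompletionSource<bool> current = Interlocked.Exchange(ref this.tcs, null);
        if (current == null)
        {
            throw new InvalidOperationException("There is no active operation.");
        }

        current.SetResult(false);
    }
}

// Hypothetical shared harness: in every phase, each worker calls `contended` once and
// at most one call may succeed; `betweenPhases` runs once while all workers wait.
public static class RaceHarness
{
    public static void Run(int workers, TimeSpan duration, Action contended, Action betweenPhases, string name)
    {
        using (CountdownEvent counter = new CountdownEvent(2))
        using (CancellationTokenSource cts = new CancellationTokenSource(duration))
        using (Barrier barrier = new Barrier(workers, b => { counter.Reset(); betweenPhases(); }))
        {
            Action<CancellationToken> doWork = token =>
            {
                int iterations = 0;
                int exceptions = 0;
                try
                {
                    while (!token.IsCancellationRequested)
                    {
                        ++iterations;
                        bool succeeded = false;
                        try
                        {
                            contended();
                            succeeded = true;
                        }
                        catch (InvalidOperationException)
                        {
                            ++exceptions;
                        }

                        // Signal() returns true only when the count reaches zero,
                        // i.e. on a second success within the same phase.
                        if (succeeded && counter.Signal())
                        {
                            throw new InvalidDataException("Consistency violated! Multiple calls to " + name + " were allowed.");
                        }

                        barrier.SignalAndWait(token);
                    }
                }
                catch (OperationCanceledException)
                {
                }

                Console.WriteLine("Thread {0} ran {1} iterations and caught {2} exceptions.",
                    Thread.CurrentThread.ManagedThreadId, iterations, exceptions);
            };

            Task[] tasks = new Task[workers];
            for (int i = 0; i < tasks.Length; ++i)
            {
                tasks[i] = Task.Factory.StartNew(() => doWork(cts.Token), TaskCreationOptions.LongRunning);
            }

            // Throws an AggregateException if any worker detected a violation.
            Task.WaitAll(tasks);
        }
    }
}
```

With this, the start race becomes RaceHarness.Run(16, TimeSpan.FromSeconds(30.0d), () => operation.Start(), () => operation.Complete(), "Start"), and the completion race is the same call with the two delegates swapped (and "Complete" as the name), preceded by one initial operation.Start().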

When run in Release mode on my 12-core system, the two slow tests produced the following output:

Test Name: ShouldOnlyAllowOneTaskInFlightWithManyTryingToStart
Test Outcome: Passed
Result StandardOutput:
Thread 39 ran 370480 iterations and caught 346010 exceptions.
Thread 42 ran 370480 iterations and caught 347760 exceptions.
Thread 35 ran 370480 iterations and caught 349375 exceptions.
Thread 41 ran 370480 iterations and caught 349517 exceptions.
Thread 37 ran 370480 iterations and caught 347413 exceptions.
Thread 38 ran 370480 iterations and caught 345398 exceptions.
Thread 36 ran 370480 iterations and caught 347464 exceptions.
Thread 4 ran 370480 iterations and caught 347261 exceptions.
Thread 28 ran 370480 iterations and caught 345826 exceptions.
Thread 32 ran 370480 iterations and caught 346443 exceptions.
Thread 33 ran 370480 iterations and caught 345520 exceptions.
Thread 40 ran 370480 iterations and caught 349619 exceptions.
Thread 34 ran 370480 iterations and caught 349496 exceptions.
Thread 8 ran 370480 iterations and caught 347067 exceptions.
Thread 6 ran 370480 iterations and caught 346850 exceptions.
Thread 26 ran 370480 iterations and caught 346181 exceptions.

Test Name: ShouldOnlyAllowOneTaskInFlightWithManyTryingToComplete
Test Outcome: Passed
Result StandardOutput:
Thread 43 ran 377994 iterations and caught 353909 exceptions.
Thread 55 ran 377994 iterations and caught 356977 exceptions.
Thread 50 ran 377994 iterations and caught 357073 exceptions.
Thread 53 ran 377994 iterations and caught 352969 exceptions.
Thread 44 ran 377994 iterations and caught 352972 exceptions.
Thread 47 ran 377994 iterations and caught 353356 exceptions.
Thread 49 ran 377994 iterations and caught 356587 exceptions.
Thread 45 ran 377994 iterations and caught 352553 exceptions.
Thread 51 ran 377994 iterations and caught 354122 exceptions.
Thread 58 ran 377994 iterations and caught 353903 exceptions.
Thread 52 ran 377994 iterations and caught 353760 exceptions.
Thread 54 ran 377994 iterations and caught 353365 exceptions.
Thread 56 ran 377994 iterations and caught 357139 exceptions.
Thread 9 ran 377994 iterations and caught 353156 exceptions.
Thread 48 ran 377994 iterations and caught 353143 exceptions.
Thread 57 ran 377994 iterations and caught 354926 exceptions.

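Note that the completion race test ran slightly faster than the start race test (377,994 versus 370,480 iterations per thread in 30 seconds, about 2% more). The difference comes from the unconditional heap allocation of a TaskCompletionSource in AsyncOperation.Start(): in the start race, each of the 16 threads allocates one on every attempt even though at most one can win, whereas in the completion race Start() is called only once per phase. I haven’t yet come up with a way to avoid this allocation in managed code… maybe an enterprising reader can help.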
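One idea, not from the original post and not measured here: a caller that loses the race never publishes its TaskCompletionSource, so it could keep that instance in a [ThreadStatic] field and reuse it on its next attempt. The winner’s allocation remains, since every operation needs a fresh task, but losing callers would stop allocating. The spare field is a hypothetical addition; this does not remove the cost of throwing an InvalidOperationException on each failed attempt.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public class AsyncOperation
{
    // Hypothetical per-thread cache for a TaskCompletionSource that lost the race.
    // It was never published, so no other code holds a reference to it, and it can
    // be reused by any AsyncOperation instance on the same thread.
    [ThreadStatic]
    private static TaskCompletionSource<bool> spare;

    private TaskCompletionSource<bool> tcs;

    public Task Start()
    {
        TaskCompletionSource<bool> current = spare ?? new TaskCompletionSource<bool>();
        spare = null;

        if (Interlocked.CompareExchange(ref this.tcs, current, null) != null)
        {
            // Lost the race: keep the unused instance for this thread's next attempt.
            spare = current;
            throw new InvalidOperationException("An operation is already in progress.");
        }

        return current.Task;
    }

    public void Complete()
    {
        TaskCompletionSource<bool> current = Interlocked.Exchange(ref this.tcs, null);
        if (current == null)
        {
            throw new InvalidOperationException("There is no active operation.");
        }

        current.SetResult(false);
    }
}
```

Whether this actually closes the gap between the two race tests would need to be measured.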
