
I recently had the odd but totally reasonable requirement to split a single C# async iterator between its consumers. If you’re not familiar with async iterators in C#, I’d recommend you read this article from the November 2019 issue of MSDN magazine. I’ll wait.

So now, what do I mean by ‘splitting an async iterator between its consumers’? Let’s assume we have an instance of IAsyncEnumerable<int> that’ll yield the numbers from 1 to 10:

var source = AsyncEnumerable.Range(1, 10);

A ‘consumer’ of source is any piece of code that has an iteration in progress, i.e. it has, at the very minimum, called GetAsyncEnumerator on source to get an IAsyncEnumerator<int> and hasn’t disposed of it yet. ‘Splitting’ source means applying an operator Split (that is yet to be written)

var split = source.Split();

such that the resulting IAsyncEnumerable<int> satisfies a couple of properties:

  1. Independent of the number of concurrent consumers of split, source will only ever be iterated by at most one consumer.
  2. Consumers of split may try to consume any number of elements from it. More specifically, it is valid to finish iterating split at any time.
  3. Any element that source produces can be consumed by exactly one consumer of split.
  4. When the last consumer of split is finished iterating, the single consumer of source is disposed. Any next consumer of split will create a new single consumer on source.

In other words, Split is a way of sharing an iterator between multiple concurrent consumers such that the consumers take turns consuming. For example, given source and split as defined above, concurrently running the consumers like this

var task1 = Task.Run(async () => await split.ToArrayAsync());
var task2 = Task.Run(async () => await split.ToArrayAsync());
var task3 = Task.Run(async () => await split.ToArrayAsync());

var array1 = await task1;
var array2 = await task2;
var array3 = await task3;

may end up with array1 = {1, 4, 7, 10}, array2 = {2, 5, 8} and array3 = {3, 6, 9}.
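
Since the interleaving depends on scheduling, the exact arrays will differ between runs. What should hold on every run is that the arrays together form a partition of source’s output. A minimal sketch of such a check, assuming that source yields distinct elements; SplitChecks and IsPartitionOf are hypothetical names and not part of the operator:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class SplitChecks
{
    // Hypothetical helper: returns true if the buckets, taken together,
    // contain exactly the expected elements and no element appears twice.
    // Assumes that the expected elements are distinct.
    public static bool IsPartitionOf<T>(IEnumerable<T> expected, params T[][] buckets)
    {
        var all = buckets.SelectMany(b => b).ToList();
        var set = new HashSet<T>(all);

        // No duplicates across buckets, and nothing missing or extra.
        return set.Count == all.Count && set.SetEquals(expected);
    }
}
```

For the arrays above, SplitChecks.IsPartitionOf(Enumerable.Range(1, 10), array1, array2, array3) returns true; it would return false if two consumers had received the same element.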

The operator I present here never blocks a thread, so it should perform well in concurrent scenarios: the shared enumerator is handed out using lock-free Interlocked operations, and the protected section is guarded by awaiting a SemaphoreSlim. It is hopefully sufficiently commented too.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

public static class AsyncEnumerableExtensions
{
    public static IAsyncEnumerable<T> Split<T>(this IAsyncEnumerable<T> source)
    {
        // Variables shared across all consumers of the split IAsyncEnumerable.

        // With a SemaphoreSlim, we can lock a protected code section
        // without blocking any threads.
        var asyncLock = new SemaphoreSlim(1);
        var currentEnumerator = default(RefCountEnumerator<T>?);

        return Core();
        
        async IAsyncEnumerable<T> Core([EnumeratorCancellation] CancellationToken ct = default)
        {
            var localEnumerator = default(RefCountEnumerator<T>?);

            // Get a valid RefCountEnumerator<T>, which is a wrapper for the source IAsyncEnumerator
            // which tracks the number of concurrent consumers. This is done lock-free, i.e. it doesn't
            // ever block any threads.
            while (localEnumerator == null)
            {
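                localEnumerator = Volatile.Read(ref currentEnumerator);

                // If another consumer has won the creation race but hasn't published
                // the real enumerator yet, we see the Sentinel here. It must never be
                // used for iteration, so retry until the real enumerator shows up.
                if (ReferenceEquals(localEnumerator, RefCountEnumerator<T>.Sentinel))
                {
                    localEnumerator = null;
                    continue;
                }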

                if (localEnumerator == null)
                {
                    // There's no valid value in currentEnumerator. Now there's a race for creation of a new
                    // RefCountEnumerator<T>. Try to win the race by setting an intermediate sentinel value.
                    if (Interlocked.CompareExchange(ref currentEnumerator, RefCountEnumerator<T>.Sentinel, null) == null)
                    {
                        // Success. This thread may now go ahead and get the source IAsyncEnumerator.
                        // Note: The ct parameter that this iterator gets is not passed on to the
                        // source enumerator, which is shared between consumers.
                        Interlocked.Exchange(ref currentEnumerator, localEnumerator = new RefCountEnumerator<T>(source));
                    }
                }
                else
                {
                    // There was a currentEnumerator. If it's still usable, Increment will return true.
                    if (!localEnumerator.Increment())
                    {
                        // But if it's not usable anymore, we reset currentEnumerator/localEnumerator to
                        // null and try again.
                        Interlocked.CompareExchange(ref currentEnumerator, null, localEnumerator);
                        localEnumerator = null;
                    }
                }
            }

            try
            {
                while (!ct.IsCancellationRequested)
                {
                    T value;

                    // Enter the protected section. This is the key part
                    // of the Split operator. The SemaphoreSlim makes sure
                    // that only one consumer at any time may call MoveNextAsync
                    // on the source IAsyncEnumerator.
                    await asyncLock.WaitAsync(ct);

                    try
                    {
                        if (!await localEnumerator.MoveNextAsync())
                            yield break;

                        value = localEnumerator.Current;
                    }
                    finally
                    {
                        asyncLock.Release();
                    }

                    yield return value;
                }
            }
            finally
            {
                // Whenever the consumer is done iterating, we end up here.
                // By calling Decrement we signal the RefCountEnumerator to
                // decrement the concurrent consumer count and possibly
                // dispose the source IAsyncEnumerator<T>.
                if (!await localEnumerator.Decrement())
                    Interlocked.CompareExchange(ref currentEnumerator, null, localEnumerator);
            }
        }
    }

    // We need to wrap the source-IAsyncEnumerator's lifecycle
    // because the current number of consumers must be tracked
    // alongside the source-IAsyncEnumerator itself. We can't easily
    // inline this into the Split-operator code itself as it would
    // make the concurrent lock-free part much harder.
    private sealed class RefCountEnumerator<T> : IAsyncEnumerator<T>
    {
        // The Sentinel instance is used to indicate that a particular thread
        // has successfully won the race for source-enumerator creation
        // and may now proceed safely.
        public static readonly RefCountEnumerator<T> Sentinel = new();

        private int _disposed;
        private int _count = 1;
        private readonly IAsyncEnumerator<T> _baseEnumerator;
        private readonly CancellationTokenSource _cts = new();

        private RefCountEnumerator() : this(Dummy())
        {

        }

        public RefCountEnumerator(IAsyncEnumerable<T> source)
        {
            _baseEnumerator = source
                .GetAsyncEnumerator(_cts.Token);
        }

        public async ValueTask DisposeAsync()
        {
            if (Interlocked.CompareExchange(ref _disposed, 1, 0) == 0)
            {
                _cts.Cancel();
                await _baseEnumerator.DisposeAsync();
            }
        }

        // We make sure that the wrapped IAsyncEnumerator's MoveNextAsync
        // method is not called again after it has returned false.
        // This is not a strong requirement per se but might be a good
        // idea depending on the implementation of the source iterator.
        public async ValueTask<bool> MoveNextAsync()
        {
            var ret = false;

            if (Volatile.Read(ref _disposed) == 0)
            {
                ret = await _baseEnumerator.MoveNextAsync();
                if (!ret)
                    await DisposeAsync();
            }

            return ret;
        }

        // Returns true if the wrapped IAsyncEnumerator<T> is still valid.
        public bool Increment()
        {
            while (true)
            {
                var currentCount = Volatile.Read(ref _count);

                if (currentCount == 0)
                    return false;

                if (Interlocked.CompareExchange(ref _count, currentCount + 1, currentCount) == currentCount)
                    return true;
            }
        }

        // Returns false if this call released the last consumer and therefore
        // disposed the wrapped IAsyncEnumerator<T>; returns true otherwise.
        public async ValueTask<bool> Decrement()
        {
            if (Interlocked.Decrement(ref _count) == 0)
            {
                await DisposeAsync();
                
                return false;
            }

            return true;
        }

        public T Current => _baseEnumerator.Current;

        private static async IAsyncEnumerable<T> Dummy()
        {
            yield break;
        }
    }
}
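
To see the protected section in isolation, here is a stripped-down sketch of the turn-taking idea. It is not the Split operator: there is no reference counting, and the single enumerator is created once by the caller. SharedCursor, TryTakeAsync and SharedCursorDemo are hypothetical names that I only use for illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical, simplified stand-in for the protected section of Split:
// one shared enumerator, one async gate, no reference counting.
public sealed class SharedCursor<T>
{
    private readonly SemaphoreSlim _gate = new SemaphoreSlim(1);
    private readonly IAsyncEnumerator<T> _inner;
    private bool _done; // only read or written while holding _gate

    public SharedCursor(IAsyncEnumerator<T> inner) => _inner = inner;

    // Takes the next element, or reports that the source is exhausted.
    public async Task<(bool HasValue, T Value)> TryTakeAsync()
    {
        // WaitAsync suspends the caller without blocking its thread.
        await _gate.WaitAsync();
        try
        {
            if (_done || !await _inner.MoveNextAsync())
            {
                _done = true;
                return (false, default!);
            }

            return (true, _inner.Current);
        }
        finally
        {
            _gate.Release();
        }
    }
}

public static class SharedCursorDemo
{
    // Hypothetical source: 1..count with a short delay between elements.
    private static async IAsyncEnumerable<int> Numbers(int count)
    {
        for (var i = 1; i <= count; i++)
        {
            yield return i;
            await Task.Delay(1);
        }
    }

    // One consumer: keeps taking elements until the source is exhausted.
    private static async Task<List<int>> Drain(SharedCursor<int> cursor)
    {
        var taken = new List<int>();
        while (true)
        {
            var (hasValue, value) = await cursor.TryTakeAsync();
            if (!hasValue)
                return taken;
            taken.Add(value);
        }
    }

    // Runs `consumers` concurrent drains over one shared enumerator.
    public static async Task<List<int>[]> RunAsync(int count, int consumers)
    {
        await using var inner = Numbers(count).GetAsyncEnumerator();
        var cursor = new SharedCursor<int>(inner);

        return await Task.WhenAll(
            Enumerable.Range(0, consumers).Select(_ => Drain(cursor)));
    }
}
```

Calling await SharedCursorDemo.RunAsync(20, 3) gives three lists that together contain each number from 1 to 20 exactly once; how the numbers are distributed between the lists depends on scheduling.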

Some tests are in order:

using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Reactive.Threading.Tasks;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;

namespace BlogCodeTests
{
    public class AsyncEnumerableSplitTests
    {
        [Fact]
        public async Task Split_once()
        {
            var count = 100;
            var parallelism = 10;

            var source = AsyncEnumerable.Create(Range);
            var split = source.Split();

            //Range but with a bit of delay between elements
            async IAsyncEnumerator<int> Range(CancellationToken ct)
            {
                for (var i = 0; i < count; i++)
                {
                    yield return i;

                    await Task.Delay(5, ct);
                }
            }

            // Concurrently iterate `split` into buckets.
            var buckets = await Enumerable
                .Repeat(split, parallelism)
                .ToObservable()
                .SelectMany(_ => _
                    .ToObservable()
                    .ToArray())
                .ToArray()
                .ToTask();

            // There are as many buckets as parallelism says.
            buckets
                .Should()
                .HaveCount(parallelism);

            // There's at least one element in every bucket.
            buckets
                .All(x => x.Length > 0)
                .Should()
                .BeTrue();

            // All bucket sizes add up to the number of source elements.
            buckets
                .Sum(x => x.Length)
                .Should()
                .Be(count);

            // All the elements combined without duplicates
            // are source's output!
            buckets
                .SelectMany(x => x)
                .Distinct()
                .Should()
                .HaveCount(count);
        }
        
        [Fact]
        public async Task Split_twice()
        {
            // Assert that after a complete iteration of source,
            // it can be done again!

            var count = 10;
            var parallelism = 10;

            var source = AsyncEnumerable.Create(Range);
            var split = source.Split();

            //Range but with a bit of delay between elements
            async IAsyncEnumerator<int> Range(CancellationToken ct)
            {
                for (var i = 0; i < count; i++)
                {
                    yield return i;

                    await Task.Delay(10, ct);
                }
            }

            var buckets1 = await Enumerable
                .Repeat(split, parallelism)
                .ToObservable()
                .SelectMany(_ => _
                    .ToObservable()
                    .ToArray())
                .ToArray()
                .ToTask();

            var buckets2 = await Enumerable
                .Repeat(split, parallelism)
                .ToObservable()
                .SelectMany(_ => _
                    .ToObservable()
                    .ToArray())
                .ToArray()
                .ToTask();

            buckets1
                .SelectMany(x => x)
                .Should()
                .BeEquivalentTo(buckets2.SelectMany(x => x));
        }
    }
}

Happy coding!