
Easy unit testing of null argument validation (C# 8 edition)

A few years ago, I blogged about a way to automate unit testing of null argument validation. Its usage looked like this:

[Fact]
public void FullOuterJoin_Throws_If_Argument_Is_Null()
{
    var left = Enumerable.Empty<int>();
    var right = Enumerable.Empty<int>();
    TestHelper.AssertThrowsWhenArgumentNull(
        () => left.FullOuterJoin(right, x => x, y => y, (k, x, y) => 0, 0, 0, null),
        "left", "right", "leftKeySelector", "rightKeySelector", "resultSelector");
}

Basically, for each of the specified parameters, the AssertThrowsWhenArgumentNull method rewrites the lambda expression by replacing the corresponding argument with null, compiles and executes it, and checks that it throws an ArgumentNullException with the appropriate parameter name. This method has served me well for many years, as it drastically reduces the amount of code needed to test argument validation. However, I wasn’t completely satisfied with it, because I still had to specify the names of the non-nullable parameters explicitly…

C# 8 to the rescue

Yesterday, I was working on enabling C# 8 nullable reference types on an old library, and I realized that I could take advantage of the nullable metadata to automatically detect which parameters are non-nullable.

Basically, when you compile a library with nullable reference types enabled, method parameters can be annotated with a [Nullable(x)] attribute, where x is a byte value that indicates the nullability of the parameter: 0 for oblivious, 1 for non-nullable, 2 for nullable (it’s actually slightly more complicated than that, see Jon Skeet’s article on the subject). Additionally, there can be a [NullableContext(x)] attribute on the method or type that indicates the default nullability for the method or type; if a parameter doesn’t have the [Nullable] attribute, the default nullability applies.
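To see this metadata for yourself, here is a minimal sketch that lists the [Nullable] and [NullableContext] attributes found on a method, its parameters and its declaring type. The Sample and NullableMetadataDump types are hypothetical names made up for this illustration. Whether the compiler puts a given value on the parameter, the method or the type depends on which encoding it considers most compact, so the sketch just reports whatever is present instead of assuming a particular layout.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical sample type, only here to have something to inspect.
public static class Sample
{
    // 'text' is declared non-nullable, 'fallback' is declared nullable.
    public static string Describe(string text, string? fallback) => text + fallback;
}

public static class NullableMetadataDump
{
    // Returns one line per [Nullable] / [NullableContext] attribute found on the
    // declaring type, on the method itself, and on each of its parameters.
    public static List<string> Describe(MethodInfo method)
    {
        var lines = new List<string>();
        Collect(lines, "type " + method.DeclaringType?.Name, method.DeclaringType?.CustomAttributes);
        Collect(lines, "method " + method.Name, method.CustomAttributes);
        foreach (var p in method.GetParameters())
            Collect(lines, "parameter " + p.Name, p.CustomAttributes);
        return lines;
    }

    private static void Collect(List<string> lines, string owner, IEnumerable<CustomAttributeData>? attributes)
    {
        if (attributes is null)
            return;

        foreach (var a in attributes)
        {
            // The attributes are compiler-generated, so they are matched by name.
            var name = a.AttributeType.Name;
            if (name == "NullableAttribute" || name == "NullableContextAttribute")
                lines.Add(owner + ": " + a);
        }
    }
}

// Usage:
// var method = typeof(Sample).GetMethod(nameof(Sample.Describe))!;
// foreach (var line in NullableMetadataDump.Describe(method))
//     Console.WriteLine(line);
```

The exact text printed for each attribute comes from CustomAttributeData.ToString() and can vary between runtimes.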

Using these facts, it’s possible to update my old AssertThrowsWhenArgumentNull method to make it detect non-nullable parameters automatically. Here’s the result:

using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using FluentAssertions;

static class TestHelper
{
    private const string NullableContextAttributeName = "System.Runtime.CompilerServices.NullableContextAttribute";
    private const string NullableAttributeName = "System.Runtime.CompilerServices.NullableAttribute";

    public static void AssertThrowsWhenArgumentNull(Expression<Action> expr)
    {
        var realCall = expr.Body as MethodCallExpression;
        if (realCall == null)
            throw new ArgumentException("Expression body is not a method call", nameof(expr));
        
        var method = realCall.Method;
        var nullableContextAttribute =
            method.CustomAttributes
            .FirstOrDefault(a => a.AttributeType.FullName == NullableContextAttributeName)
            ??
            method.DeclaringType.GetTypeInfo().CustomAttributes
            .FirstOrDefault(a => a.AttributeType.FullName == NullableContextAttributeName);

        if (nullableContextAttribute is null)
            throw new InvalidOperationException($"The method '{method}' is not in a nullable enable context. Can't determine non-nullable parameters.");

        var defaultNullability = (Nullability)(byte)nullableContextAttribute.ConstructorArguments[0].Value;

        var realArgs = realCall.Arguments;
        var parameters = method.GetParameters();
        var paramIndexes = parameters
            .Select((p, i) => new { p, i })
            .ToDictionary(x => x.p.Name, x => x.i);
        var paramTypes = parameters
            .ToDictionary(p => p.Name, p => p.ParameterType);

        var nonNullableRefParams = parameters
            .Where(p => !p.ParameterType.GetTypeInfo().IsValueType && GetNullability(p, defaultNullability) == Nullability.NotNull);

        foreach (var param in nonNullableRefParams)
        {
            var paramName = param.Name;
            var args = realArgs.ToArray();
            args[paramIndexes[paramName]] = Expression.Constant(null, paramTypes[paramName]);
            var call = Expression.Call(realCall.Object, method, args);
            var lambda = Expression.Lambda<Action>(call);
            var action = lambda.Compile();
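            // ShouldThrow is the older FluentAssertions syntax; from version 5 on, use action.Should().Throw<ArgumentNullException>(...)
            action.ShouldThrow<ArgumentNullException>($"because parameter '{paramName}' is not nullable")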
                .And.ParamName.Should().Be(paramName);
        }
    }

    private enum Nullability
    {
        Oblivious = 0,
        NotNull = 1,
        Nullable = 2
    }

    private static Nullability GetNullability(ParameterInfo parameter, Nullability defaultNullability)
    {
        if (parameter.ParameterType.GetTypeInfo().IsValueType)
            return Nullability.NotNull;

        var nullableAttribute = parameter.CustomAttributes
            .FirstOrDefault(a => a.AttributeType.FullName == NullableAttributeName);

        if (nullableAttribute is null)
            return defaultNullability;

        var firstArgument = nullableAttribute.ConstructorArguments.First();
        if (firstArgument.ArgumentType == typeof(byte))
        {
            var value = (byte)firstArgument.Value;
            return (Nullability)value;
        }
        else
        {
            var values = (ReadOnlyCollection<CustomAttributeTypedArgument>)firstArgument.Value;

            // Probably shouldn't happen
            if (values.Count == 0)
                return defaultNullability;

            var value = (byte)values[0].Value;

            return (Nullability)value;
        }
    }
}

The unit test is now even simpler, since there’s no need to specify the parameters to validate:

[Fact]
public void FullOuterJoin_Throws_If_Argument_Is_Null()
{
    var left = Enumerable.Empty<int>();
    var right = Enumerable.Empty<int>();
    TestHelper.AssertThrowsWhenArgumentNull(
        () => left.FullOuterJoin(right, x => x, y => y, (k, x, y) => 0, 0, 0, null));
}

It will automatically check that each non-nullable parameter is properly validated.

Happy coding!

Easy unit testing of null argument validation

When unit testing a method, one of the things to test is argument validation: for instance, ensure that the method throws an ArgumentNullException when a null argument is passed for a parameter that isn’t allowed to be null. Writing this kind of test is very easy, but it’s also a tedious and repetitive task, especially if the method has many parameters… So I wrote a method that automates part of this task: it tries to pass null for each of the specified arguments, and asserts that the method throws an ArgumentNullException. Here’s an example that tests a FullOuterJoin extension method:

[Test]
public void FullOuterJoin_Throws_If_Argument_Null()
{
    var left = Enumerable.Empty<int>();
    var right = Enumerable.Empty<int>();
    TestHelper.AssertThrowsWhenArgumentNull(
        () => left.FullOuterJoin(right, x => x, y => y, (k, x, y) => 0, 0, 0, null),
        "left", "right", "leftKeySelector", "rightKeySelector", "resultSelector");
}

The first parameter is a lambda expression that represents how to call the method. In this lambda, you should only pass valid arguments. The following parameters are the names of the parameters that are not allowed to be null. For each of the specified names, AssertThrowsWhenArgumentNull will replace the corresponding argument with null in the provided lambda, compile and invoke the lambda, and assert that the method throws an ArgumentNullException.
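To make the substitution step concrete, here is a minimal, framework-free sketch of what happens for a single parameter. NullSubstitutionSketch, Greet and ThrownParamName are hypothetical names made up for this illustration; the real helper does the same rewrite for every name you pass and asserts on the result instead of returning it.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class NullSubstitutionSketch
{
    // Hypothetical method under test: it validates both of its arguments.
    public static string Greet(string greeting, string name)
    {
        if (greeting == null) throw new ArgumentNullException(nameof(greeting));
        if (name == null) throw new ArgumentNullException(nameof(name));
        return greeting + ", " + name;
    }

    // Rebuilds the call with null in place of the named argument, runs it, and
    // returns the ParamName of the ArgumentNullException (or null if none was thrown).
    public static string ThrownParamName(Expression<Action> expr, string paramName)
    {
        var call = (MethodCallExpression)expr.Body;
        var parameters = call.Method.GetParameters();
        var index = Array.FindIndex(parameters, p => p.Name == paramName);

        // Copy the original arguments and swap one of them for a typed null constant.
        var args = call.Arguments.ToArray();
        args[index] = Expression.Constant(null, parameters[index].ParameterType);

        // call.Object is null for static methods and the target instance otherwise.
        var rewritten = Expression.Call(call.Object, call.Method, args);
        var action = Expression.Lambda<Action>(rewritten).Compile();

        try
        {
            action();
            return null;
        }
        catch (ArgumentNullException ex)
        {
            return ex.ParamName;
        }
    }
}

// Usage:
// NullSubstitutionSketch.ThrownParamName(
//     () => NullSubstitutionSketch.Greet("Hello", "world"), "name");   // returns "name"
```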

Using this method, instead of writing a test for each of the arguments that are not allowed to be null, you only need one test.

Here’s the code for the TestHelper.AssertThrowsWhenArgumentNull method (you can also find it on Gist):

using System;
using System.Linq;
using System.Linq.Expressions;
using NUnit.Framework;

namespace MyLibrary.Tests
{
    static class TestHelper
    {
        public static void AssertThrowsWhenArgumentNull(Expression<TestDelegate> expr, params string[] paramNames)
        {
            var realCall = expr.Body as MethodCallExpression;
            if (realCall == null)
                throw new ArgumentException("Expression body is not a method call", "expr");

            var realArgs = realCall.Arguments;
            var paramIndexes = realCall.Method.GetParameters()
                .Select((p, i) => new { p, i })
                .ToDictionary(x => x.p.Name, x => x.i);
            var paramTypes = realCall.Method.GetParameters()
                .ToDictionary(p => p.Name, p => p.ParameterType);

            foreach (var paramName in paramNames)
            {
                var args = realArgs.ToArray();
                args[paramIndexes[paramName]] = Expression.Constant(null, paramTypes[paramName]);
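                // Pass realCall.Object so that instance methods work too (it is null for static calls)
                var call = Expression.Call(realCall.Object, realCall.Method, args);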
                var lambda = Expression.Lambda<TestDelegate>(call);
                var action = lambda.Compile();
                var ex = Assert.Throws<ArgumentNullException>(action, "Expected ArgumentNullException for parameter '{0}', but none was thrown.", paramName);
                Assert.AreEqual(paramName, ex.ParamName);
            }
        }

    }
}

Note that it is written for NUnit, but can easily be adapted to other unit test frameworks.

I used this method in my Linq.Extras library, which provides many additional extension methods for working with sequences and collections (including the FullOuterJoin method mentioned above).