
Easy unit testing of null argument validation (C# 8 edition)

A few years ago, I blogged about a way to automate unit testing of null argument validation. Its usage looked like this:

[Fact]
public void FullOuterJoin_Throws_If_Argument_Is_Null()
{
    var left = Enumerable.Empty<int>();
    var right = Enumerable.Empty<int>();
    TestHelper.AssertThrowsWhenArgumentNull(
        () => left.FullOuterJoin(right, x => x, y => y, (k, x, y) => 0, 0, 0, null),
        "left", "right", "leftKeySelector", "rightKeySelector", "resultSelector");
}

Basically, for each of the specified parameters, the AssertThrowsWhenArgumentNull method rewrites the lambda expression by replacing the corresponding argument with null, compiles and executes it, and checks that it throws an ArgumentNullException with the appropriate parameter name. This method has served me well for many years, as it drastically reduces the amount of code to test argument validation. However, I wasn’t completely satisfied with it, because I still had to specify the names of the non-nullable parameters explicitly…
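The core trick (substitute null for one argument, compile, run) can be sketched in isolation, without an assertion library. This is only an illustration: `NullArgumentRewriter` and `InvokeWithNullArgument` are hypothetical names and are not part of the original helper. The sketch returns whatever exception the rewritten call throws, so the caller can inspect its type and `ParamName`.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class NullArgumentRewriter
{
    // Hypothetical helper: replaces the argument bound to 'paramName' with a typed
    // null constant, compiles the rewritten call, runs it, and returns the exception
    // it threw (or null if it completed normally).
    public static Exception InvokeWithNullArgument(Expression<Action> expr, string paramName)
    {
        if (!(expr.Body is MethodCallExpression call))
            throw new ArgumentException("Expression body is not a method call", nameof(expr));

        // Position is the index of the parameter in the method signature, which is
        // also its index in call.Arguments (including the 'this' of extension methods)
        var parameter = call.Method.GetParameters().Single(p => p.Name == paramName);
        var args = call.Arguments.ToArray();
        args[parameter.Position] = Expression.Constant(null, parameter.ParameterType);

        // call.Object is null for static (and extension) methods, which Expression.Call accepts
        var action = Expression.Lambda<Action>(Expression.Call(call.Object, call.Method, args)).Compile();
        try
        {
            action();
            return null;
        }
        catch (Exception ex)
        {
            return ex;
        }
    }
}
```

For example, passing `() => string.Join(",", new[] { "a", "b" })` with `"value"` should yield an `ArgumentNullException` whose `ParamName` is `"value"`, while `"separator"` yields no exception because `string.Join` accepts a null separator.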

C# 8 to the rescue

Yesterday, I was working on enabling C# 8 nullable reference types on an old library, and I realized that I could take advantage of the nullable metadata to automatically detect which parameters are non-nullable.

Basically, when you compile a library with nullable reference types enabled, method parameters can be annotated with a [Nullable(x)] attribute, where x is a byte value that indicates the nullability of the parameter: 0 for oblivious, 1 for non-nullable, 2 for nullable (it’s actually slightly more complicated than that, see Jon Skeet’s article on the subject). Additionally, there can be a [NullableContext(x)] attribute on the method or type that indicates the default nullability for the method or type; if a parameter doesn’t have the [Nullable] attribute, the default nullability applies.
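For what it's worth, .NET 6 (which came out after this post) added `System.Reflection.NullabilityInfoContext`, which decodes this metadata, including the fallback to `[NullableContext]`, on your behalf. The sketch below assumes .NET 6 or later; the `NullabilityProbe` class and its `Append` sample method are hypothetical. It could plausibly stand in for the hand-written `GetNullability` method of the helper, though I haven't tried it in that role.

```csharp
using System;
using System.Linq;
using System.Reflection;

#nullable enable

public static class NullabilityProbe
{
    // Hypothetical sample method: 'text' is non-nullable, 'suffix' is nullable
    public static string Append(string text, string? suffix) => text + suffix;

    // Names of reference-type parameters that callers must not pass null to.
    // WriteState describes what may be passed in (ReadState: what may be read back).
    public static string[] NonNullableParameters(MethodInfo method)
    {
        var context = new NullabilityInfoContext();
        return method.GetParameters()
            .Where(p => !p.ParameterType.IsValueType)
            .Where(p => context.Create(p).WriteState == NullabilityState.NotNull)
            .Select(p => p.Name!)
            .ToArray();
    }
}
```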

Using these facts, it’s possible to update my old AssertThrowsWhenArgumentNull method to make it detect non-nullable parameters automatically. Here’s the result:

using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using FluentAssertions;

static class TestHelper
{
    private const string NullableContextAttributeName = "System.Runtime.CompilerServices.NullableContextAttribute";
    private const string NullableAttributeName = "System.Runtime.CompilerServices.NullableAttribute";

    public static void AssertThrowsWhenArgumentNull(Expression<Action> expr)
    {
        var realCall = expr.Body as MethodCallExpression;
        if (realCall == null)
            throw new ArgumentException("Expression body is not a method call", nameof(expr));
        
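        var method = realCall.Method;

        // [NullableContext] may be on the method itself, on its declaring type,
        // or on an enclosing type (nested types inherit the context of their container)
        var nullableContextAttribute = method.CustomAttributes
            .FirstOrDefault(a => a.AttributeType.FullName == NullableContextAttributeName);
        for (var type = method.DeclaringType; nullableContextAttribute is null && type != null; type = type.DeclaringType)
        {
            nullableContextAttribute = type.GetTypeInfo().CustomAttributes
                .FirstOrDefault(a => a.AttributeType.FullName == NullableContextAttributeName);
        }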

        if (nullableContextAttribute is null)
            throw new InvalidOperationException($"The method '{method}' is not in a nullable enable context. Can't determine non-nullable parameters.");

        var defaultNullability = (Nullability)(byte)nullableContextAttribute.ConstructorArguments[0].Value;

        var realArgs = realCall.Arguments;
        var parameters = method.GetParameters();
        var paramIndexes = parameters
            .Select((p, i) => new { p, i })
            .ToDictionary(x => x.p.Name, x => x.i);
        var paramTypes = parameters
            .ToDictionary(p => p.Name, p => p.ParameterType);

        var nonNullableRefParams = parameters
            .Where(p => !p.ParameterType.GetTypeInfo().IsValueType && GetNullability(p, defaultNullability) == Nullability.NotNull);

        foreach (var param in nonNullableRefParams)
        {
            var paramName = param.Name;
            var args = realArgs.ToArray();
            args[paramIndexes[paramName]] = Expression.Constant(null, paramTypes[paramName]);
            var call = Expression.Call(realCall.Object, method, args);
            var lambda = Expression.Lambda<Action>(call);
            var action = lambda.Compile();
            // FluentAssertions 5+ syntax (4.x used action.ShouldThrow<T>() instead)
            action.Should().Throw<ArgumentNullException>($"because parameter '{paramName}' is not nullable")
                .And.ParamName.Should().Be(paramName);
        }
    }

    private enum Nullability
    {
        Oblivious = 0,
        NotNull = 1,
        Nullable = 2
    }

    private static Nullability GetNullability(ParameterInfo parameter, Nullability defaultNullability)
    {
        if (parameter.ParameterType.GetTypeInfo().IsValueType)
            return Nullability.NotNull;

        var nullableAttribute = parameter.CustomAttributes
            .FirstOrDefault(a => a.AttributeType.FullName == NullableAttributeName);

        if (nullableAttribute is null)
            return defaultNullability;

        var firstArgument = nullableAttribute.ConstructorArguments.First();
        if (firstArgument.ArgumentType == typeof(byte))
        {
            var value = (byte)firstArgument.Value;
            return (Nullability)value;
        }
        else
        {
            var values = (ReadOnlyCollection<CustomAttributeTypedArgument>)firstArgument.Value;

            // Probably shouldn't happen
            if (values.Count == 0)
                return defaultNullability;

            var value = (byte)values[0].Value;

            return (Nullability)value;
        }
    }
}

The unit test is now even simpler, since there’s no need to specify the parameters to validate:

[Fact]
public void FullOuterJoin_Throws_If_Argument_Is_Null()
{
    var left = Enumerable.Empty<int>();
    var right = Enumerable.Empty<int>();
    TestHelper.AssertThrowsWhenArgumentNull(
        () => left.FullOuterJoin(right, x => x, y => y, (k, x, y) => 0, 0, 0, null));
}

It will automatically check that each non-nullable parameter is properly validated.

Happy coding!

Automating null checks with Linq expressions

The problem

Have you ever written code like the following?

X xx = GetX();
string name = "Default";
if (xx != null && xx.Foo != null && xx.Foo.Bar != null && xx.Foo.Bar.Baz != null)
{
    name = xx.Foo.Bar.Baz.Name;
}

I bet you have! You just need to get the value of xx.Foo.Bar.Baz.Name, but you have to test every intermediate object to ensure that it’s not null. It can quickly become annoying if the property you need is nested deep in an object graph…

A solution

Linq offers a very interesting feature that can help solve that problem: expressions. C# 3.0 makes it possible to retrieve the abstract syntax tree (AST) of a lambda expression, and perform all kinds of manipulations on it. It is also possible to dynamically generate an AST, compile it to obtain a delegate, and execute it.
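To make this concrete, here is a small illustrative sketch (the `ExpressionTreeDemo` class is hypothetical): a lambda assigned to an `Expression<Func<int, int>>` becomes a tree whose nodes can be inspected, and the same tree can be built by hand with the factory methods of the `Expression` class, then compiled into a delegate.

```csharp
using System;
using System.Linq.Expressions;

public static class ExpressionTreeDemo
{
    // Assigned to Expression<...>, the lambda becomes a data structure, not compiled code
    public static readonly Expression<Func<int, int>> AddOne = x => x + 1;

    // The same tree built by hand: a parameter node, a constant node and an Add node
    public static Expression<Func<int, int>> BuildAddOne()
    {
        ParameterExpression x = Expression.Parameter(typeof(int), "x");
        BinaryExpression body = Expression.Add(x, Expression.Constant(1));
        return Expression.Lambda<Func<int, int>>(body, x);
    }
}
```

Both `ExpressionTreeDemo.AddOne.Compile()` and `ExpressionTreeDemo.BuildAddOne().Compile()` produce a delegate that returns 42 when given 41.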

How is this related to the problem described above? Well, Linq makes it possible to analyse the AST of the expression that accesses the xx.Foo.Bar.Baz.Name property, and rewrite that AST to insert null checks where needed. So we’re going to create a NullSafeEval extension method, which takes as parameters the lambda expression defining how to access a property, and the default value to return if a null object is encountered along the way.

That method will transform the expression xx.Foo.Bar.Baz.Name into this:

    (xx == null)
    ? defaultValue
    : (xx.Foo == null)
      ? defaultValue
      : (xx.Foo.Bar == null)
        ? defaultValue
        : (xx.Foo.Bar.Baz == null)
          ? defaultValue
          : xx.Foo.Bar.Baz.Name;
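Written as ordinary C#, the rewritten expression behaves like the first method in this sketch. The classes `X`, `Foo`, `Bar` and `Baz` are hypothetical stand-ins for the types in the example. For comparison, the second method uses the null-conditional operator `?.`, introduced in C# 6 well after this post; note that its `??` also substitutes the default value when `Name` itself is null, which the conditional chain does not.

```csharp
// Hypothetical types matching the shape of the example
public class Baz { public string Name { get; set; } }
public class Bar { public Baz Baz { get; set; } }
public class Foo { public Bar Bar { get; set; } }
public class X { public Foo Foo { get; set; } }

public static class ManualNullChecks
{
    // Hand-written equivalent of the rewritten expression tree
    public static string GetName(X xx, string defaultValue) =>
        (xx == null) ? defaultValue
        : (xx.Foo == null) ? defaultValue
        : (xx.Foo.Bar == null) ? defaultValue
        : (xx.Foo.Bar.Baz == null) ? defaultValue
        : xx.Foo.Bar.Baz.Name;

    // C# 6+ null-conditional version; ?? also replaces a null Name with defaultValue
    public static string GetNameNullConditional(X xx, string defaultValue) =>
        xx?.Foo?.Bar?.Baz?.Name ?? defaultValue;
}
```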

Here’s the implementation of the NullSafeEval method:

using System;
using System.Linq.Expressions;
using System.Reflection;

// Extension methods must live in a non-nested static class; the class name is arbitrary
public static class NullSafeExtensions
{
    public static TResult NullSafeEval<TSource, TResult>(this TSource source, Expression<Func<TSource, TResult>> expression, TResult defaultValue)
    {
        // Type the constant as TResult (not as the runtime type of defaultValue), so that
        // both branches of each conditional have the same type, even when defaultValue is null
        var safeExp = Expression.Lambda<Func<TSource, TResult>>(
            NullSafeEvalWrapper(expression.Body, Expression.Constant(defaultValue, typeof(TResult))),
            expression.Parameters[0]);

        var safeDelegate = safeExp.Compile();
        return safeDelegate(source);
    }

    private static Expression NullSafeEvalWrapper(Expression expr, Expression defaultValue)
    {
        Expression obj;
        Expression safe = expr;

        // Move from the outermost member access towards the lambda parameter,
        // wrapping the result so far in "obj == null ? defaultValue : safe"
        while (!IsNullSafe(expr, out obj))
        {
            // Compare against a null constant of obj's own type
            var isNull = Expression.Equal(obj, Expression.Constant(null, obj.Type));

            safe =
                Expression.Condition
                (
                    isNull,
                    defaultValue,
                    safe
                );

            expr = obj;
        }
        return safe;
    }

    private static bool IsNullSafe(Expression expr, out Expression nullableObject)
    {
        nullableObject = null;

        if (expr is MemberExpression || expr is MethodCallExpression)
        {
            Expression obj;
            MemberExpression memberExpr = expr as MemberExpression;
            MethodCallExpression callExpr = expr as MethodCallExpression;

            if (memberExpr != null)
            {
                // Static fields don't require an instance
                FieldInfo field = memberExpr.Member as FieldInfo;
                if (field != null && field.IsStatic)
                    return true;

                // Static properties don't require an instance
                // (GetGetMethod(true) also finds non-public getters)
                PropertyInfo property = memberExpr.Member as PropertyInfo;
                if (property != null)
                {
                    MethodInfo getter = property.GetGetMethod(true);
                    if (getter != null && getter.IsStatic)
                        return true;
                }
                obj = memberExpr.Expression;
            }
            else
            {
                // Static methods (including extension methods) don't require an instance
                if (callExpr.Method.IsStatic)
                    return true;

                obj = callExpr.Object;
            }

            // Value types can't be null
            if (obj.Type.IsValueType)
                return true;

            // Instance member access or instance method call is not safe
            nullableObject = obj;
            return false;
        }
        return true;
    }
}

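In short, this code walks the expression tree from the outermost member access back to the lambda parameter, and wraps each instance property access or instance method call in a conditional expression (condition ? value if true : value if false) that returns the default value when the object being accessed is null.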
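For a single level, the wrapping step amounts to the following hand-built tree. This is an illustrative sketch on a string parameter (`OneLevel` and `NullSafeLength` are hypothetical names): a null test on the instance, then `Expression.Condition` choosing between the default value and the original member access. `Expression.Condition` requires both branches to have the same type, which is why the default value here is an `int` constant, like `Length`.

```csharp
using System;
using System.Linq.Expressions;

public static class OneLevel
{
    // Builds and compiles: s => (s == null) ? defaultValue : s.Length
    public static Func<string, int> NullSafeLength(int defaultValue)
    {
        ParameterExpression s = Expression.Parameter(typeof(string), "s");
        Expression access = Expression.Property(s, "Length");
        Expression isNull = Expression.Equal(s, Expression.Constant(null, typeof(string)));
        Expression safe = Expression.Condition(isNull, Expression.Constant(defaultValue), access);
        return Expression.Lambda<Func<string, int>>(safe, s).Compile();
    }
}
```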

And here’s how we can use this method:

string name = xx.NullSafeEval(x => x.Foo.Bar.Baz.Name, "Default");

Much clearer and more concise than our initial code, isn’t it? 🙂

Note that the proposed implementation handles not only properties, but also method calls, so we could write something like this:

string name = xx.NullSafeEval(x => x.Foo.GetBar(42).Baz.Name, "Default");

Indexers are not handled yet, but they could be added quite easily; I will leave that to you if you have a use for it 😉
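As a starting point, it helps to know how the C# compiler represents index access in a lambda. An indexer on a class such as `List<T>` appears as a call to its `get_Item` method, which `IsNullSafe` already treats as an instance method call. Single-dimensional array access, on the other hand, appears as a `BinaryExpression` whose `NodeType` is `ArrayIndex`. The sketch below (`IndexerSupport` and `GetIndexedArray` are hypothetical names) only extracts the array that would need a null check.

```csharp
using System;
using System.Linq.Expressions;

public static class IndexerSupport
{
    // Hypothetical extra case for IsNullSafe: when expr is a single-dimensional
    // array access (a[i]), returns the array expression that must be checked for
    // null; returns null when expr is not an array access.
    public static Expression GetIndexedArray(Expression expr)
    {
        if (expr.NodeType == ExpressionType.ArrayIndex && expr is BinaryExpression index)
            return index.Left;
        return null;
    }
}
```

Wiring it in would mean adding a branch to `IsNullSafe` that sets `nullableObject` to the returned array and returns false; the loop in `NullSafeEvalWrapper` would then treat it like any other instance.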

Limitations

Even though this solution can seem very interesting at first sight, please read what follows before you integrate this code into a real-world program…

  • First, the proposed code is just a proof of concept, and as such, hasn’t been thoroughly tested, so it’s probably not very reliable.
  • Secondly, keep in mind that dynamic code generation from an expression tree is tough work for the CLR, and will have a big impact on performance. A quick test shows that using the NullSafeEval method is about 10000 times slower than accessing the property directly…

    A possible approach to limit that issue would be to cache the delegates generated for each expression, to avoid regenerating them every time. Unfortunately, as far as I know there is no simple and reliable way to compare two Linq expressions, which makes it much harder to implement such a cache.

  • Last, you might have noticed that intermediate properties and methods are evaluated several times; not only is this bad for performance, but more importantly, it could have side effects that are hard to predict, depending on how the properties and methods are implemented.

    A possible workaround would be to rewrite the conditional expression as follows:

    Foo foo = null;
    Bar bar = null;
    Baz baz = null;
    var name =
        (x == null)
        ? defaultValue
        : ((foo = x.Foo) == null)
          ? defaultValue
          : ((bar = foo.Bar) == null)
            ? defaultValue
            : ((baz = bar.Baz) == null)
              ? defaultValue
              : baz.Name;
    

    Unfortunately, this is not possible in .NET 3.5: that version only supports simple expressions, so it’s not possible to declare variables, assign values to them, or write several distinct instructions. However, in .NET 4.0, support for Linq expressions has been greatly improved, making it possible to generate that kind of code. I’m currently trying to improve the NullSafeEval method to take advantage of the new .NET 4.0 features, but it turns out to be much more difficult than I had anticipated… If I manage to work it out, I’ll let you know and post the code!
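This is not the improved NullSafeEval mentioned above, only a rough sketch of how the .NET 4.0 additions (`Expression.Block`, `Expression.Variable`, `Expression.Assign`) could implement the variable-based workaround. The `NullSafeBlock` class and its `Compile` method are hypothetical. It collects the chain of instance accesses, then rebuilds it so that each intermediate value is stored in a block variable and null-checked there, which means each getter or method in the chain runs at most once. Like the original, it stops at static members and value-type instances.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

public static class NullSafeBlock
{
    // Hypothetical .NET 4.0+ variant: returns a compiled, null-safe delegate
    public static Func<TSource, TResult> Compile<TSource, TResult>(
        Expression<Func<TSource, TResult>> expression, TResult defaultValue)
    {
        // Collect instance member accesses / calls, outermost first
        var links = new List<Expression>();
        Expression root = expression.Body;
        while (TryGetInstance(root, out Expression instance))
        {
            links.Add(root);
            root = instance;
        }
        links.Reverse(); // innermost (closest to the parameter) first

        var fallback = Expression.Constant(defaultValue, typeof(TResult));
        var body = Build(root, links, 0, fallback);
        return Expression.Lambda<Func<TSource, TResult>>(body, expression.Parameters).Compile();
    }

    // Emits: instance == null ? fallback : { tmp = <link applied to instance>; <rest, using tmp> }
    private static Expression Build(Expression instance, List<Expression> links, int i, Expression fallback)
    {
        if (i == links.Count) // end of chain: 'instance' now holds the final value
            return instance.Type == fallback.Type ? instance : Expression.Convert(instance, fallback.Type);

        Expression applied = WithInstance(links[i], instance);
        ParameterExpression tmp = Expression.Variable(applied.Type, "tmp" + i);
        Expression rest = Expression.Block(
            new[] { tmp },
            Expression.Assign(tmp, applied),
            Build(tmp, links, i + 1, fallback));

        return Expression.Condition(
            Expression.Equal(instance, Expression.Constant(null, instance.Type)),
            fallback,
            rest);
    }

    // Rebuilds a member access or method call on top of a new instance expression
    private static Expression WithInstance(Expression link, Expression instance)
    {
        if (link is MemberExpression member)
            return member.Update(instance);
        var call = (MethodCallExpression)link;
        return call.Update(instance, call.Arguments);
    }

    // Only instance accesses on reference types need a null check
    private static bool TryGetInstance(Expression expr, out Expression instance)
    {
        instance = expr is MemberExpression member ? member.Expression
            : expr is MethodCallExpression call ? call.Object
            : null;
        return instance != null && !instance.Type.IsValueType;
    }
}
```

Because `Compile` returns a delegate, a caller could also store it (for example in a static field) and reuse it. That would avoid recompiling on every call without needing to compare expressions.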

To conclude, I wouldn’t recommend using that technique in real programs, at least not in its current state. However, it gives an interesting insight into the possibilities offered by Linq expressions. If you’re new to this, you should know that Linq expressions are used (among other things):

  • To generate SQL queries in ORMs like Linq to SQL or Entity Framework
  • To build complex predicates dynamically, like in the PredicateBuilder class by Joseph Albahari
  • To implement “static reflection”, which has generated a lot of buzz on technical blogs lately