Simple Dependency Injection with C#

level: 200

Background 🧪

I wrote an extension to xUnit.net called Mettle that enables dependency injection (“DI”) for test class constructors and test methods. It loosely mimics JUnit’s DI feature and the features of various JavaScript unit test frameworks that let an assert class be injected into a test method.
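Injecting a dependency into a test method mostly comes down to reflection: read the method's parameters, resolve each one from an IServiceProvider, and invoke the method with the results. The sketch below shows only that mechanic. MethodInjector, DictionaryProvider, AssertHelper, and SampleTests are hypothetical names, and the sketch makes no claim about how Mettle actually hooks into xUnit.net.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical stand-in for an injectable assert helper.
public class AssertHelper
{
    public void True(bool condition)
    {
        if (!condition)
            throw new InvalidOperationException("Assertion failed.");
    }
}

// Hypothetical test class: the method asks for its dependency as a parameter.
public class SampleTests
{
    public string Verify(AssertHelper assert)
    {
        assert.True(1 + 1 == 2);
        return "passed";
    }
}

public static class MethodInjector
{
    // Resolves every parameter from the provider, then invokes the method.
    public static object Invoke(object target, MethodInfo method, IServiceProvider provider)
    {
        object[] args = method.GetParameters()
            .Select(p => provider.GetService(p.ParameterType)
                ?? throw new InvalidOperationException(
                    $"No service registered for {p.ParameterType.Name}."))
            .ToArray();

        return method.Invoke(target, args);
    }
}

// Minimal provider backed by a dictionary of factories (an assumption for this sketch).
public class DictionaryProvider : IServiceProvider
{
    private readonly Dictionary<Type, Func<object>> factories = new();

    public void Add(Type type, Func<object> factory) => this.factories[type] = factory;

    public object GetService(Type serviceType) =>
        this.factories.TryGetValue(serviceType, out var factory) ? factory() : null;
}
```

A test runner would call something like MethodInjector.Invoke in place of a plain method.Invoke(target, null), so the test author never constructs the assert helper directly.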

Rather than tying Mettle or xUnit.net to a specific DI library, I created a simple default DI implementation. It is written against a few basic interfaces built on IServiceProvider, a core interface that ships with .NET.

These interfaces allow other dependency injection frameworks to replace the default implementation when using Mettle.
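To illustrate why coding against interfaces makes the DI backend replaceable, here is a sketch in which a host depends only on an IServiceProviderFactory (the same shape as the interface in the listing later in this post). TestHost, SingleServiceProvider, IGreeter, and both factory classes are hypothetical; swapping the factory changes which provider, and therefore which implementation, the host receives without touching the host's code.

```csharp
using System;

// Same shape as the IServiceProviderFactory interface in the article's code.
public interface IServiceProviderFactory
{
    IServiceProvider CreateProvider();
}

public interface IGreeter { string Greet(string name); }

public class DefaultGreeter : IGreeter
{
    public string Greet(string name) => $"Hello, {name}";
}

public class CustomGreeter : IGreeter
{
    public string Greet(string name) => $"Hi, {name}!";
}

// Tiny provider that maps one service type to one factory.
public class SingleServiceProvider : IServiceProvider
{
    private readonly Type serviceType;
    private readonly Func<object> factory;

    public SingleServiceProvider(Type serviceType, Func<object> factory)
    {
        this.serviceType = serviceType;
        this.factory = factory;
    }

    public object GetService(Type type) => type == this.serviceType ? this.factory() : null;
}

public class DefaultProviderFactory : IServiceProviderFactory
{
    public IServiceProvider CreateProvider() =>
        new SingleServiceProvider(typeof(IGreeter), () => new DefaultGreeter());
}

// Stands in for "another DI framework" plugged in by the user.
public class CustomProviderFactory : IServiceProviderFactory
{
    public IServiceProvider CreateProvider() =>
        new SingleServiceProvider(typeof(IGreeter), () => new CustomGreeter());
}

// The host only knows about IServiceProviderFactory, so the backend is swappable.
public class TestHost
{
    private readonly IServiceProviderFactory factory;

    public TestHost(IServiceProviderFactory factory = null)
    {
        this.factory = factory ?? new DefaultProviderFactory();
    }

    public string Run(string name)
    {
        var provider = this.factory.CreateProvider();
        var greeter = (IGreeter)provider.GetService(typeof(IGreeter));
        return greeter.Greet(name);
    }
}
```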

The code below is that default implementation. It can help explain the basic mechanics of how DI is implemented in various libraries.

DI Basics 💡

DI exists to handle the resolution of dependencies when creating new instances of classes and to enable Inversion of Control (“IoC”). IoC lets developers code against abstractions whose implementations can be switched during a program’s runtime through dynamic binding.
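A minimal sketch of that idea, with hypothetical ILogger, ConsoleLogger, NullLogger, and Composition types: Controller depends only on the ILogger abstraction, the concrete logger is chosen at runtime, and the Log call is dispatched through the interface to whichever implementation was picked.

```csharp
using System;

public interface ILogger
{
    string Log(string message);
}

public class ConsoleLogger : ILogger
{
    public string Log(string message) => $"console: {message}";
}

public class NullLogger : ILogger
{
    public string Log(string message) => string.Empty;
}

// Controller knows only the abstraction, never a concrete logger type.
public class Controller
{
    private readonly ILogger logger;

    public Controller(ILogger logger)
    {
        this.logger = logger ?? throw new ArgumentNullException(nameof(logger));
    }

    public string Handle(string request) => this.logger.Log($"handled {request}");
}

public static class Composition
{
    // The implementation is selected at runtime; the interface call binds
    // dynamically to whichever type was passed in.
    public static Controller Create(bool verbose)
    {
        ILogger logger = verbose ? (ILogger)new ConsoleLogger() : new NullLogger();
        return new Controller(logger);
    }
}
```

A DI container automates the Composition step: instead of hand-writing the choice, you register which implementation to use and let the container build Controller.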

Instead of:

var instance = new Controller(new Logger(), new DbContext()); 


You’ll see something like:

var instance = container.Resolve(typeof(Controller));


Dependency injection libraries ship a container class that acts as a special object factory. When the container creates a new object, it locates the registered dependencies by type and injects the required instances into the new object.
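The following is a sketch of that factory behavior using reflection-based constructor injection, a technique many containers use. TinyContainer and its Register/Resolve methods are hypothetical and unrelated to the SimpleServiceProvider shown later, which only calls parameterless constructors. Resolve picks the public constructor with the most parameters and resolves each parameter recursively.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical minimal container: maps abstractions to concrete types and
// builds objects by resolving constructor parameters recursively.
public class TinyContainer
{
    private readonly Dictionary<Type, Type> registrations = new();

    public void Register<TService, TImplementation>() where TImplementation : TService
    {
        this.registrations[typeof(TService)] = typeof(TImplementation);
    }

    public object Resolve(Type type)
    {
        // Use the registered implementation if present, otherwise the type itself.
        var concrete = this.registrations.TryGetValue(type, out var impl) ? impl : type;

        if (concrete.IsAbstract) // also true for interfaces
            throw new InvalidOperationException($"No implementation registered for {type.Name}.");

        // Choose the public constructor with the most parameters.
        ConstructorInfo ctor = concrete.GetConstructors()
            .OrderByDescending(c => c.GetParameters().Length)
            .First();

        // Resolve each constructor parameter by its type, recursively.
        object[] args = ctor.GetParameters()
            .Select(p => this.Resolve(p.ParameterType))
            .ToArray();

        return ctor.Invoke(args);
    }
}

public interface ILogger { }
public class Logger : ILogger { }
public class DbContext { }

public class Controller
{
    public Controller(ILogger logger, DbContext db)
    {
        this.Logger = logger;
        this.Db = db;
    }

    public ILogger Logger { get; }
    public DbContext Db { get; }
}
```

The sketch has no cycle detection, so two types that depend on each other would recurse until the stack overflows; production containers guard against that.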

Sometimes the dependency has to be created for each request, sometimes the dependency is a singleton object, and sometimes the dependency is only created once within a certain lifetime such as an HTTP request.

Most C# DI libraries require you to register types during application startup and then build the container. Once the container is built from the registered dependencies, you can use it to create objects and resolve dependencies.
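Here is a sketch of that register-then-build split, assuming hypothetical ContainerBuilder and BuiltContainer types. Build copies the registrations into a read-only map, so registrations added to the builder afterwards do not leak into a container that has already been built.

```csharp
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;

// Hypothetical builder: collects registrations during startup.
public class ContainerBuilder
{
    private readonly Dictionary<Type, Func<object>> factories = new();

    public ContainerBuilder Add(Type type, Func<object> factory)
    {
        this.factories[type] = factory ?? throw new ArgumentNullException(nameof(factory));
        return this;
    }

    // Snapshot the registrations into a read-only container.
    public BuiltContainer Build() =>
        new BuiltContainer(new ReadOnlyDictionary<Type, Func<object>>(
            new Dictionary<Type, Func<object>>(this.factories)));
}

public class BuiltContainer : IServiceProvider
{
    private readonly IReadOnlyDictionary<Type, Func<object>> factories;

    public BuiltContainer(IReadOnlyDictionary<Type, Func<object>> factories)
    {
        this.factories = factories;
    }

    public object GetService(Type serviceType) =>
        this.factories.TryGetValue(serviceType, out var factory) ? factory() : null;
}
```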

Some C# DI libraries implement the IServiceProvider interface on the container class. The interface has a single method, object GetService(Type serviceType), which returns null when no service of the requested type is available.
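Because IServiceProvider exposes only that one non-generic method, callers usually wrap it in typed helpers (Microsoft.Extensions.DependencyInjection, for example, ships GetService&lt;T&gt; and GetRequiredService&lt;T&gt; extension methods). The sketch below implements the interface and adds two hypothetical helpers, Resolve&lt;T&gt; and ResolveRequired&lt;T&gt;; MapServiceProvider is likewise a made-up name.

```csharp
using System;
using System.Collections.Generic;

// A minimal container implementing IServiceProvider's single method.
public class MapServiceProvider : IServiceProvider
{
    private readonly Dictionary<Type, Func<IServiceProvider, object>> factories = new();

    public MapServiceProvider Add(Type type, Func<IServiceProvider, object> factory)
    {
        this.factories[type] = factory;
        return this;
    }

    // Per the IServiceProvider contract, unknown types yield null rather than an exception.
    public object GetService(Type serviceType)
    {
        if (serviceType == null)
            throw new ArgumentNullException(nameof(serviceType));

        // Passing "this" lets a factory resolve its own dependencies.
        return this.factories.TryGetValue(serviceType, out var factory) ? factory(this) : null;
    }
}

// Hypothetical typed helpers layered on the non-generic GetService(Type).
public static class ServiceProviderHelpers
{
    public static T Resolve<T>(this IServiceProvider provider) where T : class =>
        provider.GetService(typeof(T)) as T;

    public static T ResolveRequired<T>(this IServiceProvider provider) where T : class =>
        provider.Resolve<T>()
            ?? throw new InvalidOperationException($"{typeof(T).Name} is not registered.");
}
```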

Simple DI Implementation 👨‍💻

The code below is a simple DI container implementation. The name of the container class in the sample is SimpleServiceProvider. It supports the following styles of object creation:

  • transient: a new object per request to resolve a dependency.

  • scoped: a new object per lifetime scope (a scoped container). Only one instance of a given type is created during the life of that scope. Once the scope is disposed, the instance is disposed (if it implements IDisposable) and removed.

  • singleton: an object that is allowed only one instance for the lifetime of the program.
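The three lifetimes can be sketched with a shared registry and disposable scopes. This is a different structure from the article's SimpleServiceProvider, which instead copies factories into a child provider per scope; LifetimeRegistry, LifetimeScope, and the Lifetime enum are hypothetical. The sketch assumes each scope is used from one thread, locks only the shared singleton cache, throws KeyNotFoundException for unregistered types, and never disposes singletons.

```csharp
using System;
using System.Collections.Generic;

public enum Lifetime { Transient, Scoped, Singleton }

// Hypothetical registry shared by every scope; it also owns the singleton cache.
public class LifetimeRegistry
{
    internal readonly Dictionary<Type, (Lifetime lifetime, Func<object> factory)> Registrations = new();
    internal readonly Dictionary<Type, object> Singletons = new();

    public void Add(Type type, Lifetime lifetime, Func<object> factory) =>
        this.Registrations[type] = (lifetime, factory);

    public LifetimeScope CreateScope() => new LifetimeScope(this);
}

// A scope caches scoped instances and disposes them when the scope ends.
public class LifetimeScope : IDisposable
{
    private readonly LifetimeRegistry registry;
    private readonly Dictionary<Type, object> scoped = new();

    public LifetimeScope(LifetimeRegistry registry) => this.registry = registry;

    public object Resolve(Type type)
    {
        var (lifetime, factory) = this.registry.Registrations[type];
        switch (lifetime)
        {
            case Lifetime.Transient:
                return factory(); // new object on every call
            case Lifetime.Scoped:
                if (!this.scoped.TryGetValue(type, out var s))
                    this.scoped[type] = s = factory(); // one per scope
                return s;
            default:
                lock (this.registry.Singletons)
                {
                    if (!this.registry.Singletons.TryGetValue(type, out var single))
                        this.registry.Singletons[type] = single = factory(); // one per registry
                    return single;
                }
        }
    }

    public void Dispose()
    {
        foreach (var instance in this.scoped.Values)
            (instance as IDisposable)?.Dispose();
        this.scoped.Clear();
    }
}
```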

Basics of using a lifetime scope

// assumes container is a SimpleServiceProvider with IContract
// registered through AddScoped.
IContract scopedInstance = null;

// create a lifetime scope.
using (var scope = (IScopedServiceProviderLifetime)container.GetService(typeof(IScopedServiceProviderLifetime)))
{
    var scopedContainer = scope.Provider;
    scopedInstance = (IContract)scopedContainer.GetService(typeof(IContract));
    var scopedInstance2 = (IContract)scopedContainer.GetService(typeof(IContract));

    // prints True: both instances come from the same lifetime scope.
    Console.WriteLine(Object.ReferenceEquals(scopedInstance, scopedInstance2));
}

var scopedInstance3 = (IContract)container.GetService(typeof(IContract));
// prints False because scopedInstance3 was created
// outside of the lifetime scope that created scopedInstance.
Console.WriteLine(Object.ReferenceEquals(scopedInstance, scopedInstance3));

The Code

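    using System;
    using System.Collections.Concurrent;
    using System.Collections.Generic;

    // IAssert, AssertImpl, ITestOutputHelper and TestOutputHelper are provided
    // by Mettle and xUnit.net and are not defined in this listing.

    public interface IScopedServiceProviderLifetime : IDisposable
    {
        IServiceProvider Provider { get; }
    }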

    public interface IServiceProviderFactory
    {
        IServiceProvider CreateProvider();
    }


    public class SimpleServiceProvider : IServiceProvider, IDisposable
    {
        private ConcurrentDictionary<Type, Func<IServiceProvider, object>> factories =
            new();

        private ScopedLifetime scopedLifetime;

        public SimpleServiceProvider()
        {
            this.scopedLifetime = new ScopedLifetime();
            this.factories.TryAdd(typeof(IAssert), s => AssertImpl.Current);
            this.factories.TryAdd(typeof(ScopedLifetime), s => this.scopedLifetime);
            this.factories.TryAdd(typeof(IScopedServiceProviderLifetime), s => new SimpleScopedServiceLifetime(this));
            this.AddScoped(typeof(ITestOutputHelper), s => new TestOutputHelper());
        }

        /// <summary>Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.</summary>
        public void Dispose()
        {
            if (this.scopedLifetime == null)
                return;

            foreach (var disposable in this.scopedLifetime.GetDisposables())
                disposable.Dispose();

            this.scopedLifetime?.Clear();
            this.factories?.Clear();
            this.scopedLifetime = null;
            this.factories = null;
        }

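        public object GetService(Type type)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            if (this.factories.TryGetValue(type, out var factory))
                return factory(this);

            // Fall back to creating unregistered concrete classes that have a
            // public parameterless constructor. Interfaces, abstract classes and
            // value types resolve to null, as IServiceProvider expects.
            if (!type.IsValueType && !type.IsAbstract && type.GetConstructor(Type.EmptyTypes) != null)
                return Activator.CreateInstance(type);

            return null;
        }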

        public void AddSingleton(Type type, object instance)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            this.scopedLifetime.SetState(type, instance);
            this.factories.TryAdd(type, s => instance);
        }

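        public void AddSingleton(Type type, Func<IServiceProvider, object> activator)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            if (activator == null)
                throw new ArgumentNullException(nameof(activator));

            // Create the instance once, on first use, against this root provider
            // so every scope shares it. Tracking it in the root lifetime lets the
            // root's Dispose() dispose it as well.
            var lazy = new Lazy<object>(() =>
            {
                var instance = activator(this);
                this.scopedLifetime?.SetState(type, instance);
                return instance;
            });

            this.factories.TryAdd(type, s => lazy.Value);
        }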

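        public void AddScoped(Type type, Func<IServiceProvider, object> activator)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            if (activator == null)
                throw new ArgumentNullException(nameof(activator));

            // Each provider (root or scope) owns its own ScopedLifetime, so the
            // instance is cached per provider that resolves it.
            this.factories.TryAdd(type, s =>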
            {
                var sl = s.GetService(typeof(ScopedLifetime));
                if (sl == null)
                    return null;

                var scope = (ScopedLifetime)sl;
                if (scope.ContainsKey(type))
                    return scope.GetState(type);

                var r = activator(s);
                scope.SetState(type, r);
                return r;
            });
        }

        public void AddTransient(Type type)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            this.factories.TryAdd(type, s => Activator.CreateInstance(type));
        }

        public void AddTransient(Type type, Func<IServiceProvider, object> activator)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));

            if (activator == null)
                throw new ArgumentNullException(nameof(activator));

            this.factories.TryAdd(type, activator);
        }

        public class ScopedLifetime
        {
            private readonly ConcurrentDictionary<Type, object> state = new();

            public bool ContainsKey(Type type)
            {
                return this.state.ContainsKey(type);
            }

            public void SetState(Type type, object instance)
            {
                this.state[type] = instance;
            }

            public object GetState(Type type)
            {
                this.state.TryGetValue(type, out var instance);
                return instance;
            }

            public void Clear()
            {
                this.state.Clear();
            }

            public IEnumerable<IDisposable> GetDisposables()
            {
                var list = new List<IDisposable>();
                foreach (var kv in this.state)
                    if (kv.Value is IDisposable disposable)
                        list.Add(disposable);

                return list;
            }
        }

        private class SimpleScopedServiceLifetime : IScopedServiceProviderLifetime
        {
            private SimpleServiceProvider provider;

            public SimpleScopedServiceLifetime(SimpleServiceProvider provider)
            {
                var provider2 = new SimpleServiceProvider();
                foreach (var kv in provider.factories)
                {
                    if (kv.Key == typeof(ScopedLifetime))
                        continue;

                    if (provider2.factories.ContainsKey(kv.Key))
                        continue;

                    provider2.factories.TryAdd(kv.Key, kv.Value);
                }

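                // Expose (and later dispose) the child provider, not the parent;
                // otherwise scoped instances leak into the root and disposing
                // the scope would dispose the root provider.
                this.provider = provider2;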
            }

            public IServiceProvider Provider => this.provider;

            /// <summary>Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.</summary>
            public void Dispose()
            {
                this.provider?.Dispose();
                this.provider = null;
            }
        }