Skip to content

Commit

Permalink
Merge pull request #925 from jbogard/fixing-behavior-registration
Browse files Browse the repository at this point in the history
Fixing registration problem; moving exception behaviors first so that…
  • Loading branch information
jbogard committed Jul 10, 2023
2 parents 7fe73da + f28cdc3 commit c295291
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 40 deletions.
29 changes: 13 additions & 16 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using MediatR;
using MediatR.NotificationPublishers;
using MediatR.Pipeline;
using MediatR.Registration;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -133,15 +134,14 @@ public MediatRServiceConfiguration AddBehavior<TImplementationType>(ServiceLifet
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList();

if (implementedBehaviorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
foreach (var implementedBehaviorType in implementedGenericInterfaces)
{
BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -233,15 +233,14 @@ public MediatRServiceConfiguration AddStreamBehavior<TImplementationType>(Servic
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList();

if (implementedBehaviorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
foreach (var implementedBehaviorType in implementedGenericInterfaces)
{
StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -320,15 +319,14 @@ public MediatRServiceConfiguration AddRequestPreProcessor<TImplementationType>(
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPreProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList();

if (implementedPreProcessorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var implementedPreProcessorType in implementedPreProcessorTypes)
foreach (var implementedPreProcessorType in implementedGenericInterfaces)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -406,15 +404,14 @@ public MediatRServiceConfiguration AddRequestPostProcessor<TImplementationType>(
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPostProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList();

if (implementedPostProcessorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var implementedPostProcessorType in implementedPostProcessorTypes)
foreach (var implementedPostProcessorType in implementedGenericInterfaces)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime));
}
Expand Down
18 changes: 15 additions & 3 deletions src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,26 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,

configuration.Invoke(serviceConfig);

if (!serviceConfig.AssembliesToRegister.Any())
return services.AddMediatR(serviceConfig);
}

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="configuration">Configuration options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services,
MediatRServiceConfiguration configuration)
{
if (!configuration.AssembliesToRegister.Any())
{
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}

ServiceRegistrar.AddMediatRClasses(services, serviceConfig);
ServiceRegistrar.AddMediatRClasses(services, configuration);

ServiceRegistrar.AddRequiredServices(services, serviceConfig);
ServiceRegistrar.AddRequiredServices(services, configuration);

return services;
}
Expand Down
26 changes: 13 additions & 13 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
}
}

private static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
{
var openInterface = closedInterface.GetGenericTypeDefinition();
var arguments = closedInterface.GenericTypeArguments;
Expand All @@ -161,7 +161,7 @@ private static bool IsOpenGeneric(this Type type)
return type.IsGenericTypeDefinition || type.ContainsGenericParameters;
}

private static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
internal static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
{
return FindInterfacesThatClosesCore(pluggedType, templateType).Distinct();
}
Expand Down Expand Up @@ -221,6 +221,17 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAdd(notificationPublisherServiceDescriptor);

// Register pre processors, then post processors, then behaviors
if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
}
else
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
}

if (serviceConfiguration.RequestPreProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient));
Expand All @@ -242,17 +253,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
{
services.TryAddEnumerable(serviceDescriptor);
}

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
}
else
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
}
}

private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType)
Expand Down
102 changes: 94 additions & 8 deletions test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,11 @@ public class NotAnOpenBehavior : IPipelineBehavior<Ping, Pong>
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
}

public class ThrowingBehavior : IPipelineBehavior<Ping, Pong>
{
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => throw new Exception(request.Message);
}

public class NotAnOpenStreamBehavior : IStreamPipelineBehavior<Ping, Pong>
{
public IAsyncEnumerable<Pong> Handle(Ping request, StreamHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
Expand Down Expand Up @@ -524,6 +529,27 @@ public void Should_pick_up_base_exception_behaviors()
output.Messages.ShouldContain("Logging generic exception");
}

[Fact]
public void Should_handle_exceptions_from_behaviors()
{
var output = new Logger();
IServiceCollection services = new ServiceCollection();
services.AddSingleton(output);
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddBehavior<ThrowingBehavior>();
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();

Should.Throw<Exception>(async () => await mediator.Send(new Ping {Message = "Ping"}));

output.Messages.ShouldContain("Ping Logged by Generic Type");
output.Messages.ShouldContain("Logging generic exception");
}

[Fact]
public void Should_pick_up_exception_actions()
{
Expand Down Expand Up @@ -648,6 +674,16 @@ public void Should_handle_open_behavior_registration()
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -659,16 +695,26 @@ public void Should_handle_inferred_behavior_registration()

cfg.BehaviorsToRegister.Count.ShouldBe(2);

cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
cfg.BehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerBehavior));
cfg.BehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
cfg.BehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterBehavior));
cfg.BehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}


Expand All @@ -681,16 +727,26 @@ public void Should_handle_inferred_stream_behavior_registration()

cfg.StreamBehaviorsToRegister.Count.ShouldBe(2);

cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
cfg.StreamBehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerStreamBehavior));
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
cfg.StreamBehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterStreamBehavior));
cfg.StreamBehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -702,16 +758,26 @@ public void Should_handle_inferred_pre_processor_registration()

cfg.RequestPreProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
cfg.RequestPreProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
cfg.RequestPreProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -723,16 +789,26 @@ public void Should_handle_inferred_post_processor_registration()

cfg.RequestPostProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
cfg.RequestPostProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
cfg.RequestPostProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -756,5 +832,15 @@ public void Should_handle_open_behaviors_registration_from_a_single_type()
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Singleton);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}
}

0 comments on commit c295291

Please sign in to comment.