diff --git a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs index 0c6e1f5b..539d0b2c 100644 --- a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs +++ b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs @@ -5,6 +5,7 @@ using MediatR; using MediatR.NotificationPublishers; using MediatR.Pipeline; +using MediatR.Registration; namespace Microsoft.Extensions.DependencyInjection; @@ -133,15 +134,14 @@ public MediatRServiceConfiguration AddBehavior(ServiceLifet /// This public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) { - var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedBehaviorTypes = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>))); + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList(); - if (implementedBehaviorTypes.Count == 0) + if (implementedGenericInterfaces.Count == 0) { throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}"); } - foreach (var implementedBehaviorType in implementedBehaviorTypes) + foreach (var implementedBehaviorType in implementedGenericInterfaces) { BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); } @@ -233,15 +233,14 @@ public MediatRServiceConfiguration AddStreamBehavior(Servic /// This public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) { - var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedBehaviorTypes = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>))); + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList(); - if (implementedBehaviorTypes.Count == 0) + if (implementedGenericInterfaces.Count == 0) { throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}"); } - foreach (var implementedBehaviorType in implementedBehaviorTypes) + foreach (var implementedBehaviorType in implementedGenericInterfaces) { StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime)); } @@ -320,15 +319,14 @@ public MediatRServiceConfiguration AddRequestPreProcessor( /// This public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) { - var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedPreProcessorTypes = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>))); + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList(); - if (implementedPreProcessorTypes.Count == 0) + if (implementedGenericInterfaces.Count == 0) { throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}"); } - foreach (var implementedPreProcessorType in implementedPreProcessorTypes) + foreach (var implementedPreProcessorType in implementedGenericInterfaces) { RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime)); } @@ -406,15 +404,14 @@ public MediatRServiceConfiguration AddRequestPostProcessor( /// This public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) { - var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition()); - var implementedPostProcessorTypes = new HashSet(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>))); + var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList(); - if (implementedPostProcessorTypes.Count == 0) + if (implementedGenericInterfaces.Count == 0) { throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}"); } - foreach (var implementedPostProcessorType in implementedPostProcessorTypes) + foreach (var implementedPostProcessorType in implementedGenericInterfaces) { RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime)); } diff --git a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs index c95b5a8f..50e9787f 100644 --- a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs +++ b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs @@ -30,14 +30,26 @@ public static IServiceCollection AddMediatR(this IServiceCollection services, configuration.Invoke(serviceConfig); - if (!serviceConfig.AssembliesToRegister.Any()) + return services.AddMediatR(serviceConfig); + } + + /// + /// Registers handlers and mediator types from the specified assemblies + /// + /// Service collection + /// Configuration options + /// Service collection + public static IServiceCollection AddMediatR(this IServiceCollection services, + MediatRServiceConfiguration configuration) + { + if (!configuration.AssembliesToRegister.Any()) { throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers."); } - ServiceRegistrar.AddMediatRClasses(services, serviceConfig); + ServiceRegistrar.AddMediatRClasses(services, configuration); - ServiceRegistrar.AddRequiredServices(services, serviceConfig); + ServiceRegistrar.AddRequiredServices(services, configuration); return services; } diff --git a/src/MediatR/Registration/ServiceRegistrar.cs b/src/MediatR/Registration/ServiceRegistrar.cs index 6d1cdd39..0d69e158 100644 --- a/src/MediatR/Registration/ServiceRegistrar.cs +++ b/src/MediatR/Registration/ServiceRegistrar.cs @@ -138,7 +138,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List } } - private static bool CouldCloseTo(this Type openConcretion, Type closedInterface) + internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface) { var openInterface = closedInterface.GetGenericTypeDefinition(); var arguments = closedInterface.GenericTypeArguments; @@ -161,7 +161,7 @@ private static bool IsOpenGeneric(this Type type) return type.IsGenericTypeDefinition || type.ContainsGenericParameters; } - private static IEnumerable FindInterfacesThatClose(this Type pluggedType, Type templateType) + internal static IEnumerable FindInterfacesThatClose(this Type pluggedType, Type templateType) { return FindInterfacesThatClosesCore(pluggedType, templateType).Distinct(); } @@ -221,6 +221,17 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi services.TryAdd(notificationPublisherServiceDescriptor); // Register pre processors, then post processors, then behaviors + if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions) + { + RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>)); + RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>)); + } + else + { + RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>)); + RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>)); + } + if (serviceConfiguration.RequestPreProcessorsToRegister.Any()) { services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient)); @@ -242,17 +253,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi { services.TryAddEnumerable(serviceDescriptor); } - - if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions) - { - RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>)); - RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>)); - } - else - { - RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>)); - RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>)); - } } private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType) diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs index 28c4ec78..1d231e83 100644 --- a/test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs @@ -361,6 +361,11 @@ public class NotAnOpenBehavior : IPipelineBehavior public Task Handle(Ping request, RequestHandlerDelegate next, CancellationToken cancellationToken) => next(); } + public class ThrowingBehavior : IPipelineBehavior + { + public Task Handle(Ping request, RequestHandlerDelegate next, CancellationToken cancellationToken) => throw new Exception(request.Message); + } + public class NotAnOpenStreamBehavior : IStreamPipelineBehavior { public IAsyncEnumerable Handle(Ping request, StreamHandlerDelegate next, CancellationToken cancellationToken) => next(); @@ -524,6 +529,27 @@ public void Should_pick_up_base_exception_behaviors() output.Messages.ShouldContain("Logging generic exception"); } + [Fact] + public void Should_handle_exceptions_from_behaviors() + { + var output = new Logger(); + IServiceCollection services = new ServiceCollection(); + services.AddSingleton(output); + services.AddMediatR(cfg => + { + cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly); + cfg.AddBehavior(); + }); + var provider = services.BuildServiceProvider(); + + var mediator = provider.GetRequiredService(); + + Should.Throw(async () => await mediator.Send(new Ping {Message = "Ping"})); + + output.Messages.ShouldContain("Ping Logged by Generic Type"); + output.Messages.ShouldContain("Logging generic exception"); + } + [Fact] public void Should_pick_up_exception_actions() { @@ -648,6 +674,16 @@ public void Should_handle_open_behavior_registration() cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } [Fact] @@ -659,16 +695,26 @@ public void Should_handle_inferred_behavior_registration() cfg.BehaviorsToRegister.Count.ShouldBe(2); - cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>)); + cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior)); cfg.BehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerBehavior)); cfg.BehaviorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.BehaviorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.BehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient); - cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>)); + cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior)); cfg.BehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterBehavior)); cfg.BehaviorsToRegister[1].ImplementationFactory.ShouldBeNull(); cfg.BehaviorsToRegister[1].ImplementationInstance.ShouldBeNull(); cfg.BehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } @@ -681,16 +727,26 @@ public void Should_handle_inferred_stream_behavior_registration() cfg.StreamBehaviorsToRegister.Count.ShouldBe(2); - cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>)); + cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior)); cfg.StreamBehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerStreamBehavior)); cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient); - cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>)); + cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior)); cfg.StreamBehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterStreamBehavior)); cfg.StreamBehaviorsToRegister[1].ImplementationFactory.ShouldBeNull(); cfg.StreamBehaviorsToRegister[1].ImplementationInstance.ShouldBeNull(); cfg.StreamBehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } [Fact] @@ -702,16 +758,26 @@ public void Should_handle_inferred_pre_processor_registration() cfg.RequestPreProcessorsToRegister.Count.ShouldBe(2); - cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>)); + cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor)); cfg.RequestPreProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePreProcessor)); cfg.RequestPreProcessorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.RequestPreProcessorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.RequestPreProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient); - cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>)); + cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor)); cfg.RequestPreProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePreProcessor)); cfg.RequestPreProcessorsToRegister[1].ImplementationFactory.ShouldBeNull(); cfg.RequestPreProcessorsToRegister[1].ImplementationInstance.ShouldBeNull(); cfg.RequestPreProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } [Fact] @@ -723,16 +789,26 @@ public void Should_handle_inferred_post_processor_registration() cfg.RequestPostProcessorsToRegister.Count.ShouldBe(2); - cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>)); + cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor)); cfg.RequestPostProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePostProcessor)); cfg.RequestPostProcessorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.RequestPostProcessorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.RequestPostProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient); - cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>)); + cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor)); cfg.RequestPostProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePostProcessor)); cfg.RequestPostProcessorsToRegister[1].ImplementationFactory.ShouldBeNull(); cfg.RequestPostProcessorsToRegister[1].ImplementationInstance.ShouldBeNull(); cfg.RequestPostProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } [Fact] @@ -756,5 +832,15 @@ public void Should_handle_open_behaviors_registration_from_a_single_type() cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull(); cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Singleton); + + var services = new ServiceCollection(); + + cfg.RegisterServicesFromAssemblyContaining(); + + Should.NotThrow(() => + { + services.AddMediatR(cfg); + services.BuildServiceProvider(); + }); } } \ No newline at end of file