Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing registration problem; moving exception behaviors first so that… #925

Merged
merged 1 commit into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using MediatR;
using MediatR.NotificationPublishers;
using MediatR.Pipeline;
using MediatR.Registration;

namespace Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -133,15 +134,14 @@ public MediatRServiceConfiguration AddBehavior<TImplementationType>(ServiceLifet
/// <returns>This</returns>
public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList();

if (implementedBehaviorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
foreach (var implementedBehaviorType in implementedGenericInterfaces)
{
BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -233,15 +233,14 @@ public MediatRServiceConfiguration AddStreamBehavior<TImplementationType>(Servic
/// <returns>This</returns>
public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList();

if (implementedBehaviorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}");
}

foreach (var implementedBehaviorType in implementedBehaviorTypes)
foreach (var implementedBehaviorType in implementedGenericInterfaces)
{
StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -320,15 +319,14 @@ public MediatRServiceConfiguration AddRequestPreProcessor<TImplementationType>(
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPreProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList();

if (implementedPreProcessorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
}

foreach (var implementedPreProcessorType in implementedPreProcessorTypes)
foreach (var implementedPreProcessorType in implementedGenericInterfaces)
{
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime));
}
Expand Down Expand Up @@ -406,15 +404,14 @@ public MediatRServiceConfiguration AddRequestPostProcessor<TImplementationType>(
/// <returns>This</returns>
public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
var implementedPostProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList();

if (implementedPostProcessorTypes.Count == 0)
if (implementedGenericInterfaces.Count == 0)
{
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
}

foreach (var implementedPostProcessorType in implementedPostProcessorTypes)
foreach (var implementedPostProcessorType in implementedGenericInterfaces)
{
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime));
}
Expand Down
18 changes: 15 additions & 3 deletions src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,26 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,

configuration.Invoke(serviceConfig);

if (!serviceConfig.AssembliesToRegister.Any())
return services.AddMediatR(serviceConfig);
}

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="configuration">Configuration options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services,
MediatRServiceConfiguration configuration)
{
if (!configuration.AssembliesToRegister.Any())
{
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}

ServiceRegistrar.AddMediatRClasses(services, serviceConfig);
ServiceRegistrar.AddMediatRClasses(services, configuration);

ServiceRegistrar.AddRequiredServices(services, serviceConfig);
ServiceRegistrar.AddRequiredServices(services, configuration);

return services;
}
Expand Down
26 changes: 13 additions & 13 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
}
}

private static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
{
var openInterface = closedInterface.GetGenericTypeDefinition();
var arguments = closedInterface.GenericTypeArguments;
Expand All @@ -161,7 +161,7 @@ private static bool IsOpenGeneric(this Type type)
return type.IsGenericTypeDefinition || type.ContainsGenericParameters;
}

private static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
internal static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
{
return FindInterfacesThatClosesCore(pluggedType, templateType).Distinct();
}
Expand Down Expand Up @@ -221,6 +221,17 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAdd(notificationPublisherServiceDescriptor);

// Register pre processors, then post processors, then behaviors
if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
}
else
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
}

if (serviceConfiguration.RequestPreProcessorsToRegister.Any())
{
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient));
Expand All @@ -242,17 +253,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
{
services.TryAddEnumerable(serviceDescriptor);
}

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
}
else
{
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
}
}

private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType)
Expand Down
102 changes: 94 additions & 8 deletions test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,11 @@ public class NotAnOpenBehavior : IPipelineBehavior<Ping, Pong>
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
}

public class ThrowingBehavior : IPipelineBehavior<Ping, Pong>
{
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => throw new Exception(request.Message);
}

public class NotAnOpenStreamBehavior : IStreamPipelineBehavior<Ping, Pong>
{
public IAsyncEnumerable<Pong> Handle(Ping request, StreamHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
Expand Down Expand Up @@ -524,6 +529,27 @@ public void Should_pick_up_base_exception_behaviors()
output.Messages.ShouldContain("Logging generic exception");
}

[Fact]
public void Should_handle_exceptions_from_behaviors()
{
var output = new Logger();
IServiceCollection services = new ServiceCollection();
services.AddSingleton(output);
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
cfg.AddBehavior<ThrowingBehavior>();
});
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();

Should.Throw<Exception>(async () => await mediator.Send(new Ping {Message = "Ping"}));

output.Messages.ShouldContain("Ping Logged by Generic Type");
output.Messages.ShouldContain("Logging generic exception");
}

[Fact]
public void Should_pick_up_exception_actions()
{
Expand Down Expand Up @@ -648,6 +674,16 @@ public void Should_handle_open_behavior_registration()
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -659,16 +695,26 @@ public void Should_handle_inferred_behavior_registration()

cfg.BehaviorsToRegister.Count.ShouldBe(2);

cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
cfg.BehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerBehavior));
cfg.BehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
cfg.BehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterBehavior));
cfg.BehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.BehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.BehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}


Expand All @@ -681,16 +727,26 @@ public void Should_handle_inferred_stream_behavior_registration()

cfg.StreamBehaviorsToRegister.Count.ShouldBe(2);

cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
cfg.StreamBehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerStreamBehavior));
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
cfg.StreamBehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterStreamBehavior));
cfg.StreamBehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -702,16 +758,26 @@ public void Should_handle_inferred_pre_processor_registration()

cfg.RequestPreProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
cfg.RequestPreProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
cfg.RequestPreProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePreProcessor));
cfg.RequestPreProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPreProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -723,16 +789,26 @@ public void Should_handle_inferred_post_processor_registration()

cfg.RequestPostProcessorsToRegister.Count.ShouldBe(2);

cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
cfg.RequestPostProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
cfg.RequestPostProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePostProcessor));
cfg.RequestPostProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
cfg.RequestPostProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}

[Fact]
Expand All @@ -756,5 +832,15 @@ public void Should_handle_open_behaviors_registration_from_a_single_type()
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Singleton);

var services = new ServiceCollection();

cfg.RegisterServicesFromAssemblyContaining<Ping>();

Should.NotThrow(() =>
{
services.AddMediatR(cfg);
services.BuildServiceProvider();
});
}
}