diff --git a/reactor-core/src/main/java/reactor/core/scheduler/Schedulers.java b/reactor-core/src/main/java/reactor/core/scheduler/Schedulers.java index 9dbd985b60..02dfa31c55 100644 --- a/reactor-core/src/main/java/reactor/core/scheduler/Schedulers.java +++ b/reactor-core/src/main/java/reactor/core/scheduler/Schedulers.java @@ -34,6 +34,7 @@ import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.function.Predicate; import java.util.function.Supplier; import io.micrometer.core.instrument.MeterRegistry; @@ -121,6 +122,10 @@ public abstract class Schedulers { .map(Boolean::parseBoolean) .orElse(false); + static final Predicate DEFAULT_NON_BLOCKING_THREAD_PREDICATE = thread -> false; + + static Predicate nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE; + /** * Create a {@link Scheduler} which uses a backing {@link Executor} to schedule * Runnables for async operators. @@ -659,24 +664,50 @@ public static void onHandleError(String key, BiConsumer + *
  • the thread implements {@link NonBlocking}; or
  • + *
  • any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)} + * returns {@code true}.
  • + * * * @return {@code true} if blocking is forbidden in this thread, {@code false} otherwise */ public static boolean isInNonBlockingThread() { - return Thread.currentThread() instanceof NonBlocking; + return isNonBlockingThread(Thread.currentThread()); } /** * Check if calling a Reactor blocking API in the given {@link Thread} is forbidden - * or not, by checking if the thread implements {@link NonBlocking} (in which case it is - * forbidden and this method returns {@code true}). + * or not. This method returns {@code true} and will forbid the Reactor blocking API if + * any of the following conditions meet: + *
      + *
    • the thread implements {@link NonBlocking}; or
    • + *
    • any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)} + * returns {@code true}.
    • + *
    * * @return {@code true} if blocking is forbidden in that thread, {@code false} otherwise */ public static boolean isNonBlockingThread(Thread t) { - return t instanceof NonBlocking; + return t instanceof NonBlocking || nonBlockingThreadPredicate.test(t); + } + + /** + * Registers the specified {@link Predicate} that determines whether it is forbidden to call + * a Reactor blocking API in a given {@link Thread} or not. + */ + public static void registerNonBlockingThreadPredicate(Predicate predicate) { + nonBlockingThreadPredicate = nonBlockingThreadPredicate.or(predicate); + } + + /** + * Unregisters all the {@link Predicate}s registered so far via + * {@link #registerNonBlockingThreadPredicate(Predicate)}. + */ + public static void resetNonBlockingThreadPredicate() { + nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE; } /** diff --git a/reactor-core/src/test/java/reactor/core/scheduler/SchedulersTest.java b/reactor-core/src/test/java/reactor/core/scheduler/SchedulersTest.java index b0e41798b0..821fbdb5aa 100644 --- a/reactor-core/src/test/java/reactor/core/scheduler/SchedulersTest.java +++ b/reactor-core/src/test/java/reactor/core/scheduler/SchedulersTest.java @@ -19,6 +19,7 @@ import java.time.Duration; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -359,10 +360,53 @@ public void isNonBlockingThreadInstanceOf() { @Test public void isInNonBlockingThreadTrue() { - new ReactorThreadFactory.NonBlockingThread(() -> assertThat(Schedulers.isInNonBlockingThread()) - .as("isInNonBlockingThread") - .isFalse(), - "isInNonBlockingThreadTrue"); + assertNonBlockingThread(ReactorThreadFactory.NonBlockingThread::new, true); + } + + @Test + public void customNonBlockingThreadPredicate() { + assertThat(Schedulers.nonBlockingThreadPredicate) + .as("nonBlockingThreadPredicate") + .isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE); + + // The custom `Predicate` is not registered yet, + // so `CustomNonBlockingThread` will be considered blocking. + assertNonBlockingThread(CustomNonBlockingThread::new, false); + + // Now register the `Predicate` and ensure `CustomNonBlockingThread` is non-blocking. + Schedulers.registerNonBlockingThreadPredicate(t -> t instanceof CustomNonBlockingThread); + try { + assertNonBlockingThread(CustomNonBlockingThread::new, true); + } finally { + // Restore the global predicate. + Schedulers.resetNonBlockingThreadPredicate(); + } + + assertThat(Schedulers.nonBlockingThreadPredicate) + .as("nonBlockingThreadPredicate (after reset)") + .isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE); + } + + private static void assertNonBlockingThread(BiFunction threadFactory, + boolean expectedNonBlocking) { + CompletableFuture future = new CompletableFuture<>(); + Thread thread = threadFactory.apply(() -> { + try { + assertThat(Schedulers.isInNonBlockingThread()) + .as("isInNonBlockingThread") + .isEqualTo(expectedNonBlocking); + future.complete(null); + } catch (Throwable cause) { + future.completeExceptionally(cause); + } + }, "assertNonBlockingThread"); + + assertThat(Schedulers.isNonBlockingThread(thread)) + .as("isNonBlockingThread") + .isEqualTo(expectedNonBlocking); + + thread.start(); + future.join(); } @Test @@ -1457,4 +1501,10 @@ public void dispose() { } } } + + final static class CustomNonBlockingThread extends Thread { + CustomNonBlockingThread(Runnable target, String name) { + super(target, name); + } + } }