Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering a custom Predicate for determining non-blocking threads #3847

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;

import io.micrometer.core.instrument.MeterRegistry;
Expand Down Expand Up @@ -121,6 +122,10 @@ public abstract class Schedulers {
.map(Boolean::parseBoolean)
.orElse(false);

static final Predicate<Thread> DEFAULT_NON_BLOCKING_THREAD_PREDICATE = thread -> false;

static Predicate<Thread> nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE;

/**
* Create a {@link Scheduler} which uses a backing {@link Executor} to schedule
* Runnables for async operators.
Expand Down Expand Up @@ -659,24 +664,50 @@ public static void onHandleError(String key, BiConsumer<Thread, ? super Throwabl

/**
* Check if calling a Reactor blocking API in the current {@link Thread} is forbidden
* or not, by checking if the thread implements {@link NonBlocking} (in which case it is
* forbidden and this method returns {@code true}).
* or not. This method returns {@code true} and will forbid the Reactor blocking API if
* any of the following conditions meet:
* <ul>
* <li>the thread implements {@link NonBlocking}; or</li>
* <li>any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)}
* returns {@code true}.</li>
* </ul>
*
* @return {@code true} if blocking is forbidden in this thread, {@code false} otherwise
*/
public static boolean isInNonBlockingThread() {
return Thread.currentThread() instanceof NonBlocking;
return isNonBlockingThread(Thread.currentThread());
}

/**
* Check if calling a Reactor blocking API in the given {@link Thread} is forbidden
* or not, by checking if the thread implements {@link NonBlocking} (in which case it is
* forbidden and this method returns {@code true}).
* or not. This method returns {@code true} and will forbid the Reactor blocking API if
* any of the following conditions meet:
* <ul>
* <li>the thread implements {@link NonBlocking}; or</li>
* <li>any of the {@link Predicate}s registered via {@link #registerNonBlockingThreadPredicate(Predicate)}
* returns {@code true}.</li>
* </ul>
*
* @return {@code true} if blocking is forbidden in that thread, {@code false} otherwise
*/
public static boolean isNonBlockingThread(Thread t) {
return t instanceof NonBlocking;
return t instanceof NonBlocking || nonBlockingThreadPredicate.test(t);
}

/**
* Registers the specified {@link Predicate} that determines whether it is forbidden to call
* a Reactor blocking API in a given {@link Thread} or not.
*/
public static void registerNonBlockingThreadPredicate(Predicate<Thread> predicate) {
nonBlockingThreadPredicate = nonBlockingThreadPredicate.or(predicate);
}

/**
* Unregisters all the {@link Predicate}s registered so far via
* {@link #registerNonBlockingThreadPredicate(Predicate)}.
*/
public static void resetNonBlockingThreadPredicate() {
nonBlockingThreadPredicate = DEFAULT_NON_BLOCKING_THREAD_PREDICATE;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.time.Duration;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
Expand Down Expand Up @@ -359,10 +360,53 @@ public void isNonBlockingThreadInstanceOf() {

@Test
public void isInNonBlockingThreadTrue() {
new ReactorThreadFactory.NonBlockingThread(() -> assertThat(Schedulers.isInNonBlockingThread())
.as("isInNonBlockingThread")
.isFalse(),
"isInNonBlockingThreadTrue");
assertNonBlockingThread(ReactorThreadFactory.NonBlockingThread::new, true);
}

@Test
public void customNonBlockingThreadPredicate() {
assertThat(Schedulers.nonBlockingThreadPredicate)
.as("nonBlockingThreadPredicate")
.isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE);

// The custom `Predicate` is not registered yet,
// so `CustomNonBlockingThread` will be considered blocking.
assertNonBlockingThread(CustomNonBlockingThread::new, false);

// Now register the `Predicate` and ensure `CustomNonBlockingThread` is non-blocking.
Schedulers.registerNonBlockingThreadPredicate(t -> t instanceof CustomNonBlockingThread);
try {
assertNonBlockingThread(CustomNonBlockingThread::new, true);
} finally {
// Restore the global predicate.
Schedulers.resetNonBlockingThreadPredicate();
}

assertThat(Schedulers.nonBlockingThreadPredicate)
.as("nonBlockingThreadPredicate (after reset)")
.isSameAs(Schedulers.DEFAULT_NON_BLOCKING_THREAD_PREDICATE);
}

private static void assertNonBlockingThread(BiFunction<Runnable, String, Thread> threadFactory,
boolean expectedNonBlocking) {
CompletableFuture<Void> future = new CompletableFuture<>();
Thread thread = threadFactory.apply(() -> {
try {
assertThat(Schedulers.isInNonBlockingThread())
.as("isInNonBlockingThread")
.isEqualTo(expectedNonBlocking);
future.complete(null);
} catch (Throwable cause) {
future.completeExceptionally(cause);
}
Comment on lines +398 to +401
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this pattern, thanks! In other places we set the check using some atomic type, utilize a latch and then combine the await with an assertion. This allows to use the assertions API directly, awesome.

}, "assertNonBlockingThread");

assertThat(Schedulers.isNonBlockingThread(thread))
.as("isNonBlockingThread")
.isEqualTo(expectedNonBlocking);

thread.start();
future.join();
}

@Test
Expand Down Expand Up @@ -1457,4 +1501,10 @@ public void dispose() {
}
}
}

final static class CustomNonBlockingThread extends Thread {
CustomNonBlockingThread(Runnable target, String name) {
super(target, name);
}
}
}