From 169e9e8de1b1429e8efbacfd6274c2a5399989bf Mon Sep 17 00:00:00 2001 From: Gabriel Baraldi Date: Mon, 9 Sep 2024 12:10:33 -0300 Subject: [PATCH] Implement faster thread local rng for scheduler (#55501) Implement optimal uniform random number generator using the method proposed in https://github.com/swiftlang/swift/pull/39143 based on OpenSSL's implementation of it in https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 This PR also fixes some bugs found while developing it. This is a replacement for https://github.com/JuliaLang/julia/pull/50203 and fixes the issues found by @IanButterworth with both rngs (C rng: image; New scheduler rng: image). ~On my benchmarks the julia implementation seems to be almost 50% faster than the current implementation.~ With Oscar's suggestion of removing the debiasing, this is now almost 5x faster than the original implementation and almost fully branchless. We might want to backport the two previous commits since they technically fix bugs. --------- Co-authored-by: Valentin Churavy --- base/partr.jl | 55 ++++++++++++++++++++++++++++++++++++++- src/ccall.cpp | 32 +++++++++++++++++++++++ src/jl_exported_funcs.inc | 2 ++ src/julia_threads.h | 2 ++ src/scheduler.c | 9 ------- src/threading.c | 12 +++++++++ 6 files changed, 102 insertions(+), 10 deletions(-) diff --git a/base/partr.jl b/base/partr.jl index 8c95e3668ee74..6053a584af5ba 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -20,7 +20,60 @@ const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)] const heaps_lock = [SpinLock(), SpinLock()] -cong(max::UInt32) = iszero(max) ? UInt32(0) : ccall(:jl_rand_ptls, UInt32, (UInt32,), max) + UInt32(1) +""" + cong(max::UInt32) + +Return a random UInt32 in the range `1:max` except if max is 0, in that case return 0. +""" +cong(max::UInt32) = iszero(max) ? 
UInt32(0) : rand_ptls(max) + UInt32(1) #TODO: make sure users don't use 0 and remove this check + +get_ptls_rng() = ccall(:jl_get_ptls_rng, UInt64, ()) + +set_ptls_rng(seed::UInt64) = ccall(:jl_set_ptls_rng, Cvoid, (UInt64,), seed) + +""" + rand_ptls(max::UInt32) + +Return a random UInt32 in the range `0:max-1` using the thread-local RNG +state. Max must be greater than 0. +""" +Base.@assume_effects :removable :inaccessiblememonly :notaskstate function rand_ptls(max::UInt32) + rngseed = get_ptls_rng() + val, seed = rand_uniform_max_int32(max, rngseed) + set_ptls_rng(seed) + return val % UInt32 +end + +# This implementation is based on OpenSSL's implementation of rand_uniform +# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99 +# Comments are vendored from their implementation as well. +# For the original developer check the PR to Swift https://github.com/apple/swift/pull/39143. + +# Essentially it boils down to incrementally generating a fixed point +# number on the interval [0, 1) and multiplying this number by the upper +# range limit. Once it is certain what the fractional part contributes to +# the integral part of the product, the algorithm has produced a definitive +# result. +""" + rand_uniform_max_int32(max::UInt32, seed::UInt64) + +Return a random UInt32 in the range `0:max-1` using the given seed. +Max must be greater than 0. +""" +Base.@assume_effects :total function rand_uniform_max_int32(max::UInt32, seed::UInt64) + if max == UInt32(1) + return UInt32(0), seed + end + # We are generating a fixed point number on the interval [0, 1). + # Multiplying this by the range gives us a number on [0, max). 
+ # The high word of the multiplication result represents the integral part + # This is not completely unbiased as it's missing the fractional part of the original implementation but it's good enough for our purposes + seed = UInt64(69069) * seed + UInt64(362437) + prod = (UInt64(max)) * (seed % UInt32) # 64 bit product + i = prod >> 32 % UInt32 # integral part + return i % UInt32, seed +end + function multiq_sift_up(heap::taskheap, idx::Int32) diff --git a/src/ccall.cpp b/src/ccall.cpp index 36808e13fdbf9..7ab8cfa974d6f 100644 --- a/src/ccall.cpp +++ b/src/ccall.cpp @@ -22,6 +22,8 @@ TRANSFORMED_CCALL_STAT(jl_cpu_wake); TRANSFORMED_CCALL_STAT(jl_gc_safepoint); TRANSFORMED_CCALL_STAT(jl_get_ptls_states); TRANSFORMED_CCALL_STAT(jl_threadid); +TRANSFORMED_CCALL_STAT(jl_get_ptls_rng); +TRANSFORMED_CCALL_STAT(jl_set_ptls_rng); TRANSFORMED_CCALL_STAT(jl_get_tls_world_age); TRANSFORMED_CCALL_STAT(jl_get_world_counter); TRANSFORMED_CCALL_STAT(jl_gc_enable_disable_finalizers_internal); @@ -1692,6 +1694,36 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs) ai.decorateInst(tid); return mark_or_box_ccall_result(ctx, tid, retboxed, rt, unionall, static_rt); } + else if (is_libjulia_func(jl_get_ptls_rng)) { + ++CCALL_STAT(jl_get_ptls_rng); + assert(lrt == getInt64Ty(ctx.builder.getContext())); + assert(!isVa && !llvmcall && nccallargs == 0); + JL_GC_POP(); + Value *ptls_p = get_current_ptls(ctx); + const int rng_offset = offsetof(jl_tls_states_t, rngseed); + Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t))); + setName(ctx.emission_context, rng_ptr, "rngseed_ptr"); + LoadInst *rng_value = ctx.builder.CreateAlignedLoad(getInt64Ty(ctx.builder.getContext()), rng_ptr, Align(sizeof(void*))); + setName(ctx.emission_context, rng_value, "rngseed"); + jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + 
ai.decorateInst(rng_value); + return mark_or_box_ccall_result(ctx, rng_value, retboxed, rt, unionall, static_rt); + } + else if (is_libjulia_func(jl_set_ptls_rng)) { + ++CCALL_STAT(jl_set_ptls_rng); + assert(lrt == getVoidTy(ctx.builder.getContext())); + assert(!isVa && !llvmcall && nccallargs == 1); + JL_GC_POP(); + Value *ptls_p = get_current_ptls(ctx); + const int rng_offset = offsetof(jl_tls_states_t, rngseed); + Value *rng_ptr = ctx.builder.CreateInBoundsGEP(getInt8Ty(ctx.builder.getContext()), ptls_p, ConstantInt::get(ctx.types().T_size, rng_offset / sizeof(int8_t))); + setName(ctx.emission_context, rng_ptr, "rngseed_ptr"); + assert(argv[0].V->getType() == getInt64Ty(ctx.builder.getContext())); + auto store = ctx.builder.CreateAlignedStore(argv[0].V, rng_ptr, Align(sizeof(void*))); + jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe); + ai.decorateInst(store); + return ghostValue(ctx, jl_nothing_type); + } else if (is_libjulia_func(jl_get_tls_world_age)) { bool toplevel = !(ctx.linfo && jl_is_method(ctx.linfo->def.method)); if (!toplevel) { // top level code does not see a stable world age during execution diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 7f1636ad9ad80..7abf2b055bb8c 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -452,6 +452,8 @@ XX(jl_test_cpu_feature) \ XX(jl_threadid) \ XX(jl_threadpoolid) \ + XX(jl_get_ptls_rng) \ + XX(jl_set_ptls_rng) \ XX(jl_throw) \ XX(jl_throw_out_of_memory_error) \ XX(jl_too_few_args) \ diff --git a/src/julia_threads.h b/src/julia_threads.h index 7c6de1896ca13..b697a0bf030ed 100644 --- a/src/julia_threads.h +++ b/src/julia_threads.h @@ -18,6 +18,8 @@ extern "C" { JL_DLLEXPORT int16_t jl_threadid(void); JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT; +JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT; +JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT; // JULIA_ENABLE_THREADING may be controlled by 
altering JULIA_THREADS in Make.user diff --git a/src/scheduler.c b/src/scheduler.c index bd7da13aa42e3..bb2f85b52283f 100644 --- a/src/scheduler.c +++ b/src/scheduler.c @@ -84,15 +84,6 @@ JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSA extern int jl_gc_mark_queue_obj_explicit(jl_gc_mark_cache_t *gc_cache, jl_gc_markqueue_t *mq, jl_value_t *obj) JL_NOTSAFEPOINT; -// parallel task runtime -// --- - -JL_DLLEXPORT uint32_t jl_rand_ptls(uint32_t max) // [0, n) -{ - jl_ptls_t ptls = jl_current_task->ptls; - return cong(max, &ptls->rngseed); -} - // initialize the threading infrastructure // (called only by the main thread) void jl_init_threadinginfra(void) diff --git a/src/threading.c b/src/threading.c index 2f3719b89fac3..44b1192528531 100644 --- a/src/threading.c +++ b/src/threading.c @@ -314,6 +314,18 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT return -1; // everything else uses threadpool -1 (does not belong to any threadpool) } +// get thread local rng seed +JL_DLLEXPORT uint64_t jl_get_ptls_rng(void) JL_NOTSAFEPOINT +{ + return jl_current_task->ptls->rngseed; +} + +// set thread local rng seed +JL_DLLEXPORT void jl_set_ptls_rng(uint64_t new_seed) JL_NOTSAFEPOINT +{ + jl_current_task->ptls->rngseed = new_seed; +} + jl_ptls_t jl_init_threadtls(int16_t tid) { #ifndef _OS_WINDOWS_