Skip to content

Commit

Permalink
Chunking for GreedyScheduler (#77)
Browse files Browse the repository at this point in the history
* fix rebase conflict

* fix rebase conflict

* propagate PR#84 to chunking variant

* changelog + NotGiven instead of nothing

* default to 10*nthreads for chunking with greedy scheduler

* changelog
  • Loading branch information
carstenbauer committed Mar 18, 2024
1 parent f3a3838 commit 989f0e9
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Version 0.5.0
- ![Enhancement][badge-enhancement] Uses of `@local` within `@tasks` no longer require users to declare the type of the task local value; it can be inferred automatically if a type is not provided.
- ![Enhancement][badge-enhancement] Made `using OhMyThreads: ...` more explicit in examples in the documentation and docstrings.
- ![BREAKING][badge-breaking] The `DynamicScheduler` (default) and the `StaticScheduler` now support a `chunksize` argument to specify the desired size of chunks instead of the number of chunks (`nchunks`). Note that `chunksize` and `nchunks` are mutually exclusive. (This is unlikely to break existing code but technically could because the type parameter has changed from `Bool` to `ChunkingMode`.)
- ![BREAKING][badge-breaking] The greedy scheduler now supports chunking (similar to the static and dynamic scheduler). You can opt into it with, e.g., `chunking=true`. (This is unlikely to break existing code but technically could because we introduced a new type parameter for `GreedyScheduler`.)
- ![BREAKING][badge-breaking] `DynamicScheduler` and `StaticScheduler` don't support `nchunks=0` or `chunksize=0` any longer. Instead, chunking can now be turned off via an explicit new keyword argument `chunking=false`.
- ![BREAKING][badge-breaking] Within a `@tasks` block, task-local values must from now on be defined via `@local` instead of `@init` (renamed).
- ![BREAKING][badge-breaking] The (already deprecated) `SpawnAllScheduler` has been dropped.
Expand Down
56 changes: 51 additions & 5 deletions src/implementation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ end
# NOTE: once v1.12 releases we should switch this to wait(t; throw=false)
wait_nothrow(t) = Base._wait(t)

# GreedyScheduler
# GreedyScheduler w/o chunking
function _tmapreduce(f,
op,
Arrs,
::Type{OutputType},
scheduler::GreedyScheduler,
scheduler::GreedyScheduler{NoChunking},
mapreduce_kwargs)::OutputType where {OutputType}
ntasks_desired = scheduler.ntasks
if Base.IteratorSize(first(Arrs)) isa Base.SizeUnknown
Expand Down Expand Up @@ -233,6 +233,55 @@ function _tmapreduce(f,
mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
end

# GreedyScheduler w/ chunking
#
# Threaded map-reduce for a `GreedyScheduler` with chunking enabled:
#   1. the index chunks of the first input array are put into a buffered `Channel`,
#   2. at most `scheduler.ntasks` tasks (never more than there are chunks) are spawned, each
#      of which repeatedly takes a chunk from the channel, takes views of all input arrays
#      on that chunk, and map-reduces them serially with `f` and `op`,
#   3. the per-task results are finally reduced with `op`.
#
# Throws an `ArgumentError` if the first input has `Base.SizeUnknown` iterator size, because
# such an input cannot be chunked up front.
function _tmapreduce(f,
        op,
        Arrs,
        ::Type{OutputType},
        scheduler::GreedyScheduler,
        mapreduce_kwargs)::OutputType where {OutputType}
    if Base.IteratorSize(first(Arrs)) isa Base.SizeUnknown
        throw(ArgumentError("SizeUnknown iterators in combination with a greedy scheduler and chunking are currently not supported."))
    end
    check_all_have_same_indices(Arrs)
    chnks = _chunks(scheduler, first(Arrs))
    ntasks_desired = scheduler.ntasks
    # Spawning more tasks than there are chunks would only create idle tasks.
    ntasks = min(length(chnks), ntasks_desired)

    # The channel is sized to hold all chunks; a spawned producer task fills it.
    ch = Channel{typeof(first(chnks))}(length(chnks); spawn = true) do ch
        for args in chnks
            put!(ch, args)
        end
    end
    tasks = map(1:ntasks) do _
        # Note, calling `promise_task_local` here is only safe because we're assuming that
        # Base.mapreduce isn't going to magically try to do multithreading on us...
        @spawn mapreduce(promise_task_local(op), ch; mapreduce_kwargs...) do inds
            # Views avoid copying the chunk's elements out of the input arrays.
            args = map(A -> view(A, inds), Arrs)
            mapreduce(promise_task_local(f), promise_task_local(op), args...)
        end
    end
    # Doing this because of https://github.com/JuliaFolds2/OhMyThreads.jl/issues/82
    # The idea is that if the channel gets fully consumed before a task gets started up,
    # then if the user does not supply an `init` kwarg, we'll get an error.
    # Current way of dealing with this is just filtering out `mapreduce_empty` method
    # errors. This may not be the most stable way of dealing with things, e.g. if the
    # name of the function throwing the error changes this could break, so long term
    # we may want to try a different design.
    filtered_tasks = filter(tasks) do stabletask
        task = stabletask.t
        # Make sure the task has finished (without rethrowing) before inspecting its result.
        istaskdone(task) || wait_nothrow(task)
        if task.result isa MethodError && task.result.f == Base.mapreduce_empty
            false
        else
            true
        end
    end
    # Note, calling `promise_task_local` here is only safe because we're assuming that
    # Base.mapreduce isn't going to magically try to do multithreading on us...
    mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
end

function check_all_have_same_indices(Arrs)
let A = first(Arrs), Arrs = Arrs[2:end]
if !all(B -> eachindex(A) == eachindex(B), Arrs)
Expand Down Expand Up @@ -402,9 +451,6 @@ end
kwargs...)
_scheduler = _scheduler_from_userinput(scheduler; kwargs...)

if hasfield(typeof(_scheduler), :split) && _scheduler.split != :batch
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
end
Arrs = (A, _Arrs...)
if _scheduler isa SerialScheduler
map!(f, out, Arrs...)
Expand Down
69 changes: 62 additions & 7 deletions src/schedulers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ end
GreedyScheduler (aka :greedy)
A greedy dynamic scheduler. The elements of the collection are first put into a `Channel`
and then dynamic, non-sticky tasks are spawned to process channel content in parallel.
and then dynamic, non-sticky tasks are spawned to process the channel content in parallel.
Note that elements are processed in a non-deterministic order, and thus a potential reducing
function **must** be [commutative](https://en.wikipedia.org/wiki/Commutative_property) in
Expand All @@ -228,20 +228,75 @@ some additional overhead.
- `ntasks::Int` (default `nthreads()`):
* Determines the number of parallel tasks to be spawned.
* Setting `nchunks < nthreads()` is an effective way to use only a subset of the available threads.
* Setting `ntasks < nthreads()` is an effective way to use only a subset of the available threads.
- `chunking::Bool` (default `false`):
* Controls whether input elements are grouped into chunks (`true`) or not (`false`) before being put into the channel. This can improve performance, especially if there are many iterations, each of which is computationally cheap.
* If `nchunks` or `chunksize` are explicitly specified, `chunking` will be automatically set to `true`.
- `nchunks::Integer` (default `10 * nthreads()`):
* Determines the number of chunks (that will eventually be put into the channel).
* Increasing `nchunks` can help with [load balancing](https://en.wikipedia.org/wiki/Load_balancing_(computing)). For `nchunks <= nthreads()` there are not enough chunks for any load balancing.
- `chunksize::Integer` (default not set)
* Specifies the desired chunk size (instead of the number of chunks).
* The options `chunksize` and `nchunks` are **mutually exclusive** (only one may be a positive integer).
- `split::Symbol` (default `:scatter`):
* Determines how the collection is divided into chunks (if chunking=true).
* See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options.
"""
Base.@kwdef struct GreedyScheduler <: Scheduler
ntasks::Int = nthreads()
struct GreedyScheduler{C <: ChunkingMode} <: Scheduler
ntasks::Int
nchunks::Int
chunksize::Int
split::Symbol

function GreedyScheduler(ntasks::Int)
function GreedyScheduler(ntasks::Int, nchunks::Integer, chunksize::Integer,
split::Symbol; chunking::Bool = false)
ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer"))
new(ntasks)
if !chunking
C = NoChunking
else
if !(nchunks > 0 || chunksize > 0)
throw(ArgumentError("Either nchunks or chunksize must be a positive integer (or chunking=false)."))
end
if nchunks > 0 && chunksize > 0
throw(ArgumentError("nchunks and chunksize are mutually exclusive and only one of them may be a positive integer"))
end
C = chunksize > 0 ? FixedSize : FixedCount
end
new{C}(ntasks, nchunks, chunksize, split)
end
end

# Keyword constructor for `GreedyScheduler`.
#
# Keywords:
# - `ntasks`: number of parallel tasks to spawn (default `nthreads()`).
# - `nchunks` / `chunksize`: desired number of chunks or desired chunk size. Giving either
#   one implicitly sets `chunking = true`; the two are mutually exclusive.
# - `chunking`: whether elements are grouped into chunks (default `false`).
# - `split`: how the collection is divided into chunks (default `:scatter`).
#
# Values that were not given (or that are irrelevant because chunking is off) are passed on
# to the inner constructor as `-1`.
function GreedyScheduler(;
        ntasks::Integer = nthreads(),
        nchunks::MaybeInteger = NotGiven(),
        chunksize::MaybeInteger = NotGiven(),
        chunking::Bool = false,
        split::Symbol = :scatter)
    if isgiven(nchunks) || isgiven(chunksize)
        chunking = true
    end
    if !chunking
        nchunks = -1
        chunksize = -1
    elseif !isgiven(nchunks) && !isgiven(chunksize)
        # only choose nchunks default if chunksize hasn't been specified
        nchunks = 10 * nthreads(:default)
        chunksize = -1
    else
        # Do NOT fall back to `ntasks` for a missing `nchunks`: with only `chunksize`
        # given, that would make both values positive and the inner constructor would
        # reject them as mutually exclusive.
        nchunks = isgiven(nchunks) ? nchunks : -1
        chunksize = isgiven(chunksize) ? chunksize : -1
    end
    return GreedyScheduler(ntasks, nchunks, chunksize, split; chunking)
end

# Multi-line `text/plain` display of a `GreedyScheduler`: a header line followed by the
# number of tasks, the chunking description and the (fixed) threadpool name.
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::GreedyScheduler)
    println(io, "GreedyScheduler")
    println(io, "├ Num. tasks: ", s.ntasks)
    println(io, "├ Chunking: ", _chunkingstr(s))
    # No trailing newline after the last entry.
    print(io, "└ Threadpool: default")
end

Expand All @@ -258,7 +313,7 @@ end
chunking_mode(s::Scheduler) = chunking_mode(typeof(s))
chunking_mode(::Type{DynamicScheduler{C}}) where {C} = C
chunking_mode(::Type{StaticScheduler{C}}) where {C} = C
chunking_mode(::Type{GreedyScheduler}) = NoChunking
chunking_mode(::Type{GreedyScheduler{C}}) where {C} = C
chunking_mode(::Type{SerialScheduler}) = NoChunking

chunking_enabled(s::Scheduler) = chunking_enabled(typeof(s))
Expand Down
10 changes: 5 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
using Test, OhMyThreads
using OhMyThreads: TaskLocalValue, WithTaskLocals, @fetch, promise_task_local



sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +,
itrs = (rand(ComplexF64, 10, 10), rand(-10:10, 10, 10)),
init = complex(0.0))
Expand All @@ -12,11 +10,13 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +,
itrs = ([1 => "a", 2 => "b", 3 => "c", 4 => "d", 5 => "e"],),
init = "")]

# Test helper: a separately named entry for the scheduler test matrix that forwards all
# keyword arguments to `GreedyScheduler`.
# NOTE(review): this does not enable chunking by itself; chunking is only on when the
# forwarded kwargs request it (e.g. via `nchunks`, `chunksize` or `chunking=true`).
ChunkedGreedy(;kwargs...) = GreedyScheduler(;kwargs...)

@testset "Basics" begin
for (; ~, f, op, itrs, init) in sets_to_test
@testset "f=$f, op=$op, itrs::$(typeof(itrs))" begin
@testset for sched in (
StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler)
StaticScheduler, DynamicScheduler, GreedyScheduler, DynamicScheduler{OhMyThreads.Schedulers.NoChunking}, SerialScheduler, ChunkedGreedy)
@testset for split in (:batch, :scatter)
for nchunks in (1, 2, 6)
if sched == GreedyScheduler
Expand All @@ -30,7 +30,7 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +,
end

kwargs = (; scheduler)
if (split == :scatter || sched == GreedyScheduler) || op (vcat, *)
if (split == :scatter || sched (GreedyScheduler, ChunkedGreedy)) || op (vcat, *)
# scatter and greedy only works for commutative operators!
else
mapreduce_f_op_itr = mapreduce(f, op, itrs...)
Expand All @@ -51,7 +51,7 @@ sets_to_test = [(~ = isapprox, f = sin ∘ *, op = +,
@test tcollect(RT, (f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(RT, f.(itrs...); kwargs...) ~ map_f_itr

if sched !== GreedyScheduler
if sched ∉ (GreedyScheduler, ChunkedGreedy)
@test tmap(f, itrs...; kwargs...) ~ map_f_itr
@test tcollect((f(x...) for x in collect(zip(itrs...))); kwargs...) ~ map_f_itr
@test tcollect(f.(itrs...); kwargs...) ~ map_f_itr
Expand Down

0 comments on commit 989f0e9

Please sign in to comment.