From 94e38d785553ff80ba6ac264a38596dd73e3d227 Mon Sep 17 00:00:00 2001
From: Mason Protter
Date: Mon, 28 Jul 2025 13:33:31 -0700
Subject: [PATCH] Add a `ParallelFinalReduction` option for slow reducing operators.

---
 src/OhMyThreads.jl    |  4 +-
 src/implementation.jl | 47 +++++++++++++++++------
 src/schedulers.jl     | 89 +++++++++++++++++++++++++++++++++++--------
 test/runtests.jl      | 13 ++++---
 4 files changed, 121 insertions(+), 32 deletions(-)

diff --git a/src/OhMyThreads.jl b/src/OhMyThreads.jl
index 5f9bbd2..f5511f3 100644
--- a/src/OhMyThreads.jl
+++ b/src/OhMyThreads.jl
@@ -25,12 +25,14 @@
 include("macros.jl")
 include("tools.jl")
 include("schedulers.jl")
 using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler,
-                   SerialScheduler
+                   SerialScheduler, FinalReductionMode, SerialFinalReduction,
+                   ParallelFinalReduction
 include("implementation.jl")
 include("experimental.jl")
 export @tasks, @set, @local, @one_by_one, @only_one, @allow_boxed_captures, @disallow_boxed_captures, @localize
 export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect
 export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SerialScheduler
+export FinalReductionMode, SerialFinalReduction, ParallelFinalReduction
 
 end # module OhMyThreads
diff --git a/src/implementation.jl b/src/implementation.jl
index 967e90c..95075a5 100644
--- a/src/implementation.jl
+++ b/src/implementation.jl
@@ -10,7 +10,9 @@ using OhMyThreads.Schedulers: chunking_enabled,
                               nchunks, chunksize, chunksplit, minchunksize, has_chunksplit,
                               chunking_mode, ChunkingMode, NoChunking, FixedSize,
                               FixedCount, scheduler_from_symbol, NotGiven,
-                              isgiven
+                              isgiven,
+                              FinalReductionMode,
+                              SerialFinalReduction, ParallelFinalReduction
 using Base: @propagate_inbounds
 using Base.Threads: nthreads, @threads
 using BangBang: append!!
@@ -86,7 +88,6 @@ function has_multiple_chunks(scheduler, coll)
     end
 end
 
-
 function tmapreduce(f, op, Arrs...;
         scheduler::MaybeScheduler = NotGiven(),
         outputtype::Type = Any,
@@ -115,6 +116,35 @@ end
 
 treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)
 
+function tree_mapreduce(f, op, v)
+    if length(v) == 1
+        f(only(v))
+    elseif length(v) == 2
+        op(f(v[1]), f(v[2]))
+    else
+        l, r = v[begin:(end-begin)÷2], v[((end-begin)÷2+1):end]
+        task_r = @spawn tree_mapreduce(f, op, r)
+        result_l = tree_mapreduce(f, op, l)
+        op(result_l, fetch(task_r))
+    end
+end
+
+function final_mapreduce(op, tasks, ::SerialFinalReduction; mapreduce_kwargs...)
+    # Note, calling `promise_task_local` here is only safe because we're assuming that
+    # Base.mapreduce isn't going to magically try to do multithreading on us...
+    mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+end
+function final_mapreduce(op, tasks, ::ParallelFinalReduction; mapreduce_kwargs...)
+    # `tree_mapreduce` takes no keyword arguments, so fall back to the serial path when
+    # `mapreduce_kwargs` (e.g. `init`) are provided, or when there are no tasks at all.
+    if isempty(tasks) || !isempty(mapreduce_kwargs)
+        mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+    else
+        tree_mapreduce(fetch, op, tasks)
+    end
+end
+
+
 # DynamicScheduler: AbstractArray/Generic
 function _tmapreduce(f,
         op,
@@ -134,14 +164,13 @@ function _tmapreduce(f,
             @spawn threadpool mapreduce(promise_task_local(f), promise_task_local(op), args...;
                 $mapreduce_kwargs...)
         end
-        mapreduce(fetch, promise_task_local(op), tasks)
     else
         tasks = map(eachindex(first(Arrs))) do i
             args = map(A -> @inbounds(A[i]), Arrs)
             @spawn threadpool promise_task_local(f)(args...)
         end
-        mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
     end
+    final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end
 
 # DynamicScheduler: AbstractChunks
@@ -156,7 +185,7 @@ function _tmapreduce(f,
     tasks = map(only(Arrs)) do idcs
         @spawn threadpool promise_task_local(f)(idcs)
     end
-    mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
+    final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end
 
 # StaticScheduler: AbstractArray/Generic
@@ -284,9 +313,7 @@ function _tmapreduce(f,
             true
         end
     end
-    # Note, calling `promise_task_local` here is only safe because we're assuming that
-    # Base.mapreduce isn't going to magically try to do multithreading on us...
-    mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
+    final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end
 
 # GreedyScheduler w/ chunking
@@ -332,9 +359,7 @@ function _tmapreduce(f,
             true
         end
     end
-    # Note, calling `promise_task_local` here is only safe because we're assuming that
-    # Base.mapreduce isn't going to magically try to do multithreading on us...
-    mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
+    final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
 end
 
 function check_all_have_same_indices(Arrs)
diff --git a/src/schedulers.jl b/src/schedulers.jl
index 4ae481c..501fc8f 100644
--- a/src/schedulers.jl
+++ b/src/schedulers.jl
@@ -170,6 +170,57 @@ kind of scheduler.
 function default_nchunks end
 default_nchunks(::Type{<:Scheduler}) = nthreads(:default)
 
+
+
+tree_str = raw"""
+```
+t1 t2    t3 t4    t5 t6
+  \ |     | /      | /
+   \ |    | /      | /
+    op     op       op
+      \    /       /
+       \  /       /
+        op       /
+          \     /
+           op
+```
+"""
+
+"""
+    FinalReductionMode
+
+A trait type to decide how the final reduction is performed. Essentially,
+OhMyThreads.jl will turn a `tmapreduce(f, op, v)` call into something of
+the form
+```julia
+tasks = map(chunks(v; chunking_kwargs...)) do chunk
+    @spawn mapreduce(f, op, chunk)
+end
+final_reduction(op, tasks, ReductionMode)
+```
+where the options for `ReductionMode` are currently
+
+* `SerialFinalReduction` is the default option that should be preferred whenever `op` is not the bottleneck in your reduction. In this mode, we use a simple `mapreduce` over the tasks vector, fetching each one, i.e.
+```julia
+function final_reduction(op, tasks, ::SerialFinalReduction)
+    mapreduce(fetch, op, tasks)
+end
+```
+
+* `ParallelFinalReduction` should be opted into when `op` takes a long time relative to the time it takes to `@spawn` and `fetch` tasks (typically tens of microseconds). In this mode, the vector of tasks is split up and `op` is applied in parallel using a recursive tree-based approach.
+$tree_str
+"""
+abstract type FinalReductionMode end
+struct SerialFinalReduction <: FinalReductionMode end
+struct ParallelFinalReduction <: FinalReductionMode end
+
+FinalReductionMode(s::Scheduler) = s.final_reduction_mode
+
+FinalReductionMode(s::Symbol) = FinalReductionMode(Val(s))
+FinalReductionMode(::Val{:serial}) = SerialFinalReduction()
+FinalReductionMode(::Val{:parallel}) = ParallelFinalReduction()
+FinalReductionMode(m::FinalReductionMode) = m
+
 """
     DynamicScheduler (aka :dynamic)
 
@@ -202,16 +253,17 @@ with other multithreaded code.
 - `threadpool::Symbol` (default `:default`):
     * Possible options are `:default` and `:interactive`.
     * The high-priority pool `:interactive` should be used very carefully since tasks on this threadpool should not be allowed to run for a long time without `yield`ing as it can interfere with [heartbeat](https://en.wikipedia.org/wiki/Heartbeat_(computing)) processes.
+- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
 """
-struct DynamicScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
+struct DynamicScheduler{C <: ChunkingMode, S <: Split, FRM <: FinalReductionMode} <: Scheduler
     threadpool::Symbol
     chunking_args::ChunkingArgs{C, S}
-
-    function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs)
+    final_reduction_mode::FRM
+    function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs, frm=SerialFinalReduction())
         if !(threadpool in (:default, :interactive))
             throw(ArgumentError("threadpool must be either :default or :interactive"))
         end
-        new{chunking_mode(ca), typeof(ca.split)}(threadpool, ca)
+        new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(threadpool, ca, frm)
     end
 end
 
@@ -222,7 +274,8 @@ function DynamicScheduler(;
         chunksize::MaybeInteger = NotGiven(),
         chunking::Bool = true,
         split::Union{Split, Symbol} = Consecutive(),
-        minchunksize::Union{Nothing, Int}=nothing)
+        minchunksize::Union{Nothing, Int}=nothing,
+        final_reduction_mode::Union{Symbol, FinalReductionMode}=SerialFinalReduction())
     if isgiven(ntasks)
         if isgiven(nchunks)
             throw(ArgumentError("For the dynamic scheduler, nchunks and ntasks are aliases and only one may be provided"))
@@ -230,7 +283,8 @@ function DynamicScheduler(;
         nchunks = ntasks
     end
     ca = ChunkingArgs(DynamicScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize)
-    return DynamicScheduler(threadpool, ca)
+    frm = FinalReductionMode(final_reduction_mode)
+    return DynamicScheduler(threadpool, ca, frm)
 end
 from_symbol(::Val{:dynamic}) = DynamicScheduler
 chunking_args(sched::DynamicScheduler) = sched.chunking_args
@@ -239,7 +293,8 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::DynamicScheduler
     print(io, "DynamicScheduler", "\n")
     cstr = _chunkingstr(s.chunking_args)
     println(io, "├ Chunking: ", cstr)
-    print(io, "└ Threadpool: ", s.threadpool)
+    println(io, "├ Threadpool: ", s.threadpool)
+    print(io, "└ FinalReductionMode: ", FinalReductionMode(s))
 end
 
 """
@@ -336,14 +391,16 @@ some additional overhead.
 - `split::Union{Symbol, OhMyThreads.Split}` (default `OhMyThreads.RoundRobin()`):
     * Determines how the collection is divided into chunks (if chunking=true).
     * See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. We also allow users to pass `:consecutive` in place of `Consecutive()`, and `:roundrobin` in place of `RoundRobin()`
+- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
""" -struct GreedyScheduler{C <: ChunkingMode, S <: Split} <: Scheduler +struct GreedyScheduler{C <: ChunkingMode, S <: Split, FRM <:FinalReductionMode} <: Scheduler ntasks::Int chunking_args::ChunkingArgs{C, S} + final_reduction_mode::FRM - function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs) + function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs, frm=SerialFinalReduction()) ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer")) - return new{chunking_mode(ca), typeof(ca.split)}(ntasks, ca) + return new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(ntasks, ca, frm) end end @@ -353,12 +410,14 @@ function GreedyScheduler(; chunksize::MaybeInteger = NotGiven(), chunking::Bool = false, split::Union{Split, Symbol} = RoundRobin(), - minchunksize::Union{Nothing, Int} = nothing) + minchunksize::Union{Nothing, Int} = nothing, + final_reduction_mode::Union{Symbol,FinalReductionMode} = SerialFinalReduction()) if isgiven(nchunks) || isgiven(chunksize) chunking = true end ca = ChunkingArgs(GreedyScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize) - return GreedyScheduler(ntasks, ca) + frm = FinalReductionMode(final_reduction_mode) + return GreedyScheduler(ntasks, ca, frm) end from_symbol(::Val{:greedy}) = GreedyScheduler chunking_args(sched::GreedyScheduler) = sched.chunking_args @@ -367,9 +426,9 @@ default_nchunks(::Type{GreedyScheduler}) = 10 * nthreads(:default) function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::GreedyScheduler) print(io, "GreedyScheduler", "\n") println(io, "├ Num. tasks: ", s.ntasks) - cstr = _chunkingstr(s) - println(io, "├ Chunking: ", cstr) - print(io, "└ Threadpool: default") + println(io, "├ Chunking: ", _chunkingstr(s)) + println(io, "├ Threadpool: default") + print( io, "└ FinalReductionMode: ", FinalReductionMode(s)) end """ diff --git a/test/runtests.jl b/test/runtests.jl index e1b7703..ef72320 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -740,14 +740,16 @@ end """ DynamicScheduler ├ Chunking: fixed count ($nt), split :consecutive - └ Threadpool: default""" + ├ Threadpool: default + └ FinalReductionMode: SerialFinalReduction()""" @test repr( - "text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive)) == + "text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive, final_reduction_mode=:parallel)) == """ DynamicScheduler ├ Chunking: none - └ Threadpool: interactive""" + ├ Threadpool: interactive + └ FinalReductionMode: ParallelFinalReduction()""" @test repr("text/plain", StaticScheduler()) == """StaticScheduler @@ -764,8 +766,9 @@ end """ GreedyScheduler ├ Num. tasks: $nt - ├ Chunking: fixed count ($(10 * nt)), split :roundrobin - └ Threadpool: default""" + ├ Chunking: fixed count ($(10*nt)), split :roundrobin + ├ Threadpool: default + └ FinalReductionMode: SerialFinalReduction()""" end @testset "Boxing detection and error" begin