4 changes: 3 additions & 1 deletion src/OhMyThreads.jl
@@ -25,12 +25,14 @@ include("macros.jl")
include("tools.jl")
include("schedulers.jl")
using .Schedulers: Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler,
SerialScheduler
SerialScheduler, FinalReductionMode, SerialFinalReduction,
ParallelFinalReduction
include("implementation.jl")
include("experimental.jl")

export @tasks, @set, @local, @one_by_one, @only_one, @allow_boxed_captures, @disallow_boxed_captures, @localize
export treduce, tmapreduce, treducemap, tmap, tmap!, tforeach, tcollect
export Scheduler, DynamicScheduler, StaticScheduler, GreedyScheduler, SerialScheduler
export FinalReductionMode, SerialFinalReduction, ParallelFinalReduction

end # module OhMyThreads
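
For context on what these new exports enable: combined with the `final_reduction_mode` keyword added to the scheduler constructors in `src/schedulers.jl` below, usage could look roughly like the following sketch. Here `expensive_op`, the `sleep` duration, and the input range are placeholders for illustration, not anything taken from the diff.

```julia
using OhMyThreads

# Default behaviour: per-chunk results are combined serially on the calling task.
tmapreduce(sin, +, 1:10_000)

# Placeholder for a reducing operator that is expensive compared to the
# (roughly tens of microseconds of) overhead of spawning and fetching a task.
expensive_op(a, b) = (sleep(0.001); a + b)

# Opt into the tree-shaped parallel final reduction, either via the trait type
# or via the equivalent symbol accepted by the keyword constructor.
tmapreduce(sin, expensive_op, 1:10_000;
           scheduler = DynamicScheduler(; final_reduction_mode = ParallelFinalReduction()))
tmapreduce(sin, expensive_op, 1:10_000;
           scheduler = DynamicScheduler(; final_reduction_mode = :parallel))
```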
47 changes: 36 additions & 11 deletions src/implementation.jl
@@ -10,7 +10,9 @@ using OhMyThreads.Schedulers: chunking_enabled,
nchunks, chunksize, chunksplit, minchunksize, has_chunksplit,
chunking_mode, ChunkingMode, NoChunking,
FixedSize, FixedCount, scheduler_from_symbol, NotGiven,
isgiven
isgiven,
FinalReductionMode,
SerialFinalReduction, ParallelFinalReduction
using Base: @propagate_inbounds
using Base.Threads: nthreads, @threads
using BangBang: append!!
@@ -86,7 +88,6 @@ function has_multiple_chunks(scheduler, coll)
end
end


function tmapreduce(f, op, Arrs...;
scheduler::MaybeScheduler = NotGiven(),
outputtype::Type = Any,
Expand Down Expand Up @@ -115,6 +116,35 @@ end
treducemap(op, f, A...; kwargs...) = tmapreduce(f, op, A...; kwargs...)


function tree_mapreduce(f, op, v)
if length(v) == 1
f(only(v))
elseif length(v) == 2
op(f(v[1]), f(v[2]))
else
l, r = v[begin:(end-begin)÷2], v[((end-begin)÷2+1):end]
task_r = @spawn tree_mapreduce(f, op, r)
result_l = tree_mapreduce(f, op, l)
op(result_l, fetch(task_r))
end
end

function final_mapreduce(op, tasks, ::SerialFinalReduction; mapreduce_kwargs...)
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
end
function final_mapreduce(op, tasks, ::ParallelFinalReduction; mapreduce_kwargs...)
if isempty(tasks)
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
else
# `tree_mapreduce` takes no keyword arguments; a provided `init` (which must be a
# neutral element for `op`) is folded into the result once at the end.
result = tree_mapreduce(fetch, op, tasks)
haskey(mapreduce_kwargs, :init) ? op(mapreduce_kwargs[:init], result) : result
end
end


# DynamicScheduler: AbstractArray/Generic
function _tmapreduce(f,
op,
@@ -134,14 +164,13 @@ function _tmapreduce(f,
@spawn threadpool mapreduce(promise_task_local(f), promise_task_local(op),
args...; $mapreduce_kwargs...)
end
mapreduce(fetch, promise_task_local(op), tasks)
else
tasks = map(eachindex(first(Arrs))) do i
args = map(A -> @inbounds(A[i]), Arrs)
@spawn threadpool promise_task_local(f)(args...)
end
mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
end
final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
end

# DynamicScheduler: AbstractChunks
@@ -156,7 +185,7 @@ function _tmapreduce(f,
tasks = map(only(Arrs)) do idcs
@spawn threadpool promise_task_local(f)(idcs)
end
mapreduce(fetch, promise_task_local(op), tasks; mapreduce_kwargs...)
final_mapreduce(op, tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
end

# StaticScheduler: AbstractArray/Generic
@@ -284,9 +313,7 @@ function _tmapreduce(f,
true
end
end
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
end

# GreedyScheduler w/ chunking
@@ -332,9 +359,7 @@ function _tmapreduce(f,
true
end
end
# Note, calling `promise_task_local` here is only safe because we're assuming that
# Base.mapreduce isn't going to magically try to do multithreading on us...
mapreduce(fetch, promise_task_local(op), filtered_tasks; mapreduce_kwargs...)
final_mapreduce(op, filtered_tasks, FinalReductionMode(scheduler); mapreduce_kwargs...)
end

function check_all_have_same_indices(Arrs)
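
The `ParallelFinalReduction` branch above reduces the fetched per-chunk results with a divide-and-conquer recursion. The following self-contained sketch mirrors that idea outside the package internals; `tree_mapreduce_sketch` and `slow_add` are illustrative names, not functions from the diff.

```julia
using Base.Threads: @spawn

# Divide-and-conquer mapreduce over a non-empty vector: the right half is
# reduced on a spawned task while the left half is reduced recursively on the
# current task, so independent applications of `op` can overlap when `op` is slow.
function tree_mapreduce_sketch(f, op, v)
    if length(v) == 1
        f(only(v))
    elseif length(v) == 2
        op(f(first(v)), f(last(v)))
    else
        mid = firstindex(v) + length(v) ÷ 2 - 1
        l = view(v, firstindex(v):mid)
        r = view(v, (mid + 1):lastindex(v))
        task_r = @spawn tree_mapreduce_sketch(f, op, r)
        op(tree_mapreduce_sketch(f, op, l), fetch(task_r))
    end
end

slow_add(a, b) = (sleep(0.001); a + b)                    # stand-in for an expensive operator
tree_mapreduce_sketch(identity, slow_add, collect(1:6))   # == 21
```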
89 changes: 74 additions & 15 deletions src/schedulers.jl
@@ -170,6 +170,57 @@ kind of scheduler.
function default_nchunks end
default_nchunks(::Type{<:Scheduler}) = nthreads(:default)



tree_str = raw"""
```
t1 t2 t3 t4 t5 t6
\ | | / | /
\ | | / | /
op op op
\ / /
\ / /
op /
\ /
op
```
"""

"""
FinalReductionMode

A trait type to decide how the final reduction is performed. Essentially,
OhMyThreads.jl will turn a `tmapreduce(f, op, v)` call into something of
the form
```julia
tasks = map(chunks(v; chunking_kwargs...)) do chunk
@spawn mapreduce(f, op, chunk)
end
final_reduction(op, tasks, ReductionMode)
```
where the options for `ReductionMode` are currently

* `SerialFinalReduction` is the default option that should be preferred whenever `op` is not the bottleneck in your reduction. In this mode, we use a simple `mapreduce` over the tasks vector, fetching each one, i.e.
```julia
function final_reduction(op, tasks, ::SerialFinalReduction)
mapreduce(fetch, op, tasks)
end
```

* `ParallelFinalReduction` should be opted into when `op` takes a long time relative to the time it takes to `@spawn` and `fetch` tasks (typically tens of microseconds). In this mode, the vector of tasks is split up and `op` is applied in parallel using a recursive tree-based approach.
$tree_str
"""
abstract type FinalReductionMode end
struct SerialFinalReduction <: FinalReductionMode end
struct ParallelFinalReduction <: FinalReductionMode end

FinalReductionMode(s::Scheduler) = s.final_reduction_mode

FinalReductionMode(s::Symbol) = FinalReductionMode(Val(s))
FinalReductionMode(::Val{:serial}) = SerialFinalReduction()
FinalReductionMode(::Val{:parallel}) = ParallelFinalReduction()
FinalReductionMode(m::FinalReductionMode) = m

"""
DynamicScheduler (aka :dynamic)

@@ -202,16 +253,17 @@ with other multithreaded code.
- `threadpool::Symbol` (default `:default`):
* Possible options are `:default` and `:interactive`.
* The high-priority pool `:interactive` should be used very carefully since tasks on this threadpool should not be allowed to run for a long time without `yield`ing as it can interfere with [heartbeat](https://en.wikipedia.org/wiki/Heartbeat_(computing)) processes.
- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
"""
struct DynamicScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
struct DynamicScheduler{C <: ChunkingMode, S <: Split, FRM <: FinalReductionMode} <: Scheduler
threadpool::Symbol
chunking_args::ChunkingArgs{C, S}

function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs)
final_reduction_mode::FRM
function DynamicScheduler(threadpool::Symbol, ca::ChunkingArgs, frm=SerialFinalReduction())
if !(threadpool in (:default, :interactive))
throw(ArgumentError("threadpool must be either :default or :interactive"))
end
new{chunking_mode(ca), typeof(ca.split)}(threadpool, ca)
new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(threadpool, ca, frm)
end
end

@@ -222,15 +274,17 @@ function DynamicScheduler(;
chunksize::MaybeInteger = NotGiven(),
chunking::Bool = true,
split::Union{Split, Symbol} = Consecutive(),
minchunksize::Union{Nothing, Int}=nothing)
minchunksize::Union{Nothing, Int}=nothing,
final_reduction_mode::Union{Symbol, FinalReductionMode}=SerialFinalReduction())
if isgiven(ntasks)
if isgiven(nchunks)
throw(ArgumentError("For the dynamic scheduler, nchunks and ntasks are aliases and only one may be provided"))
end
nchunks = ntasks
end
ca = ChunkingArgs(DynamicScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize)
return DynamicScheduler(threadpool, ca)
frm = FinalReductionMode(final_reduction_mode)
return DynamicScheduler(threadpool, ca, frm)
end
from_symbol(::Val{:dynamic}) = DynamicScheduler
chunking_args(sched::DynamicScheduler) = sched.chunking_args
@@ -239,7 +293,8 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::DynamicScheduler
print(io, "DynamicScheduler", "\n")
cstr = _chunkingstr(s.chunking_args)
println(io, "├ Chunking: ", cstr)
print(io, "└ Threadpool: ", s.threadpool)
println(io, "├ Threadpool: ", s.threadpool)
print(io, "└ FinalReductionMode: ", FinalReductionMode(s))
end

"""
@@ -336,14 +391,16 @@ some additional overhead.
- `split::Union{Symbol, OhMyThreads.Split}` (default `OhMyThreads.RoundRobin()`):
* Determines how the collection is divided into chunks (if chunking=true).
* See [ChunkSplitters.jl](https://github.com/JuliaFolds2/ChunkSplitters.jl) for more details and available options. We also allow users to pass `:consecutive` in place of `Consecutive()`, and `:roundrobin` in place of `RoundRobin()`
- `final_reduction_mode` (default `SerialFinalReduction`). Switch this to `ParallelFinalReduction` or `:parallel` if your reducing operator `op` is significantly slower than the time to `@spawn` and `fetch` tasks (typically tens of microseconds).
"""
struct GreedyScheduler{C <: ChunkingMode, S <: Split} <: Scheduler
struct GreedyScheduler{C <: ChunkingMode, S <: Split, FRM <: FinalReductionMode} <: Scheduler
ntasks::Int
chunking_args::ChunkingArgs{C, S}
final_reduction_mode::FRM

function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs)
function GreedyScheduler(ntasks::Integer, ca::ChunkingArgs, frm=SerialFinalReduction())
ntasks > 0 || throw(ArgumentError("ntasks must be a positive integer"))
return new{chunking_mode(ca), typeof(ca.split)}(ntasks, ca)
return new{chunking_mode(ca), typeof(ca.split), typeof(frm)}(ntasks, ca, frm)
end
end

@@ -353,12 +410,14 @@ function GreedyScheduler(;
chunksize::MaybeInteger = NotGiven(),
chunking::Bool = false,
split::Union{Split, Symbol} = RoundRobin(),
minchunksize::Union{Nothing, Int} = nothing)
minchunksize::Union{Nothing, Int} = nothing,
final_reduction_mode::Union{Symbol,FinalReductionMode} = SerialFinalReduction())
if isgiven(nchunks) || isgiven(chunksize)
chunking = true
end
ca = ChunkingArgs(GreedyScheduler, nchunks, chunksize, split; chunking, minsize=minchunksize)
return GreedyScheduler(ntasks, ca)
frm = FinalReductionMode(final_reduction_mode)
return GreedyScheduler(ntasks, ca, frm)
end
from_symbol(::Val{:greedy}) = GreedyScheduler
chunking_args(sched::GreedyScheduler) = sched.chunking_args
@@ -367,9 +426,9 @@ default_nchunks(::Type{GreedyScheduler}) = 10 * nthreads(:default)
function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, s::GreedyScheduler)
print(io, "GreedyScheduler", "\n")
println(io, "├ Num. tasks: ", s.ntasks)
cstr = _chunkingstr(s)
println(io, "├ Chunking: ", cstr)
print(io, "└ Threadpool: default")
println(io, "├ Chunking: ", _chunkingstr(s))
println(io, "├ Threadpool: default")
print(io, "└ FinalReductionMode: ", FinalReductionMode(s))
end

"""
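
Putting the trait plumbing above together: the keyword constructors normalise a `Symbol` through the `FinalReductionMode` methods, so the symbol and trait spellings should yield equivalent schedulers. A small sketch under that assumption (`g1`/`g2` are illustrative names):

```julia
using OhMyThreads

FinalReductionMode(:serial)    # SerialFinalReduction()
FinalReductionMode(:parallel)  # ParallelFinalReduction()

# Both spellings should construct schedulers with the same final-reduction trait.
g1 = GreedyScheduler(; final_reduction_mode = :parallel)
g2 = GreedyScheduler(; final_reduction_mode = ParallelFinalReduction())
FinalReductionMode(g1) == FinalReductionMode(g2)  # true
```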
13 changes: 8 additions & 5 deletions test/runtests.jl
@@ -740,14 +740,16 @@ end
"""
DynamicScheduler
├ Chunking: fixed count ($nt), split :consecutive
└ Threadpool: default"""
├ Threadpool: default
└ FinalReductionMode: SerialFinalReduction()"""

@test repr(
"text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive)) ==
"text/plain", DynamicScheduler(; chunking = false, threadpool = :interactive, final_reduction_mode=:parallel)) ==
"""
DynamicScheduler
├ Chunking: none
└ Threadpool: interactive"""
├ Threadpool: interactive
└ FinalReductionMode: ParallelFinalReduction()"""

@test repr("text/plain", StaticScheduler()) ==
"""StaticScheduler
@@ -764,8 +766,9 @@ end
"""
GreedyScheduler
├ Num. tasks: $nt
├ Chunking: fixed count ($(10 * nt)), split :roundrobin
└ Threadpool: default"""
├ Chunking: fixed count ($(10*nt)), split :roundrobin
├ Threadpool: default
└ FinalReductionMode: SerialFinalReduction()"""
end

@testset "Boxing detection and error" begin