Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ Multi-threading changes
the first time it is called, and then always return the same result value of type `T`
every subsequent time afterwards. There are also `OncePerThread{T}` and `OncePerTask{T}` types for
similar usage with threads or tasks. ([#TBD])
* `Threads.@threads` now supports array comprehensions with syntax like `@threads [f(i) for i in 1:n]`,
filtered comprehensions like `@threads [f(i) for i in 1:n if condition(i)]`, and multi-dimensional
comprehensions like `@threads [f(i,j) for i in 1:n, j in 1:m]` (preserves dimensions). All scheduling
options (`:static`, `:dynamic`, `:greedy`) are supported. Results preserve element order for
`:static` and `:dynamic` scheduling, while `:greedy` returns elements in arbitrary order for
better performance with non-uniform workloads. Non-indexable iterators are also supported. ([#59019])

Build system changes
--------------------
Expand Down
319 changes: 275 additions & 44 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ function threading_run(fun, static)
end
end

# Helper to generate threading run code with schedule checking
function _threading_run_expr(schedule)
quote
if $(schedule === :greedy || schedule === :dynamic || schedule === :default)
threading_run(threadsfor_fun, false)
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
error("`@threads :static` cannot be used concurrently or nested")
else # :static
threading_run(threadsfor_fun, true)
end
end
end

function _threadsfor(iter, lbody, schedule)
lidx = iter.args[1] # index
range = iter.args[2]
Expand All @@ -209,17 +222,121 @@ function _threadsfor(iter, lbody, schedule)
quote
local threadsfor_fun
$func
if $(schedule === :greedy || schedule === :dynamic || schedule === :default)
threading_run(threadsfor_fun, false)
elseif ccall(:jl_in_threaded_region, Cint, ()) != 0 # :static
error("`@threads :static` cannot be used concurrently or nested")
else # :static
threading_run(threadsfor_fun, true)
end
$(_threading_run_expr(schedule))
nothing
end
end

function _threadsfor_comprehension(gen::Expr, schedule)
@assert gen.head === :generator

body = gen.args[1]

# Check if the second arg is a filter (handles both single and multi-loop with filters)
iter_or_filter = gen.args[2]
if isa(iter_or_filter, Expr) && iter_or_filter.head === :filter
# Has a filter
condition = iter_or_filter.args[1]
iterators = iter_or_filter.args[2:end] # Rest are iterators

if length(iterators) == 1
# Single loop with filter
return _threadsfor_single_iterator(body, iterators[1], condition, schedule)
else
# Multiple loops with filter: [expr for i in iter1, j in iter2, ... if cond]
vars = [iter.args[1] for iter in iterators]
ranges = [iter.args[2] for iter in iterators]

# Create product iterator and synthetic iterator with assignments
tuple_var = gensym("iter_tuple")
assignments = [:($(vars[i]) = $(tuple_var)[$i]) for i in 1:length(vars)]
new_body = quote
$(assignments...)
$(body)
end
# Also need to apply assignments in condition
new_condition = quote
$(assignments...)
$(condition)
end

product_expr = :(Iterators.product($(ranges...)))
synthetic_iter = :($(tuple_var) = $(product_expr))

return _threadsfor_single_iterator(new_body, synthetic_iter, new_condition, schedule)
end
elseif length(gen.args) > 2
# Multiple loops without filter: [expr for i in iter1, j in iter2, ...]
iterators = gen.args[2:end]

vars = [iter.args[1] for iter in iterators]
ranges = [iter.args[2] for iter in iterators]

# Create product iterator and synthetic iterator with assignments
tuple_var = gensym("iter_tuple")
assignments = [:($(vars[i]) = $(tuple_var)[$i]) for i in 1:length(vars)]
new_body = quote
$(assignments...)
$(body)
end

product_expr = :(Iterators.product($(ranges...)))
synthetic_iter = :($(tuple_var) = $(product_expr))

# Need to track dimensions for reshaping - calculate lengths of each range
dims_expr = :(tuple($([:(length($(esc(r)))) for r in ranges]...)))

return _threadsfor_single_iterator(new_body, synthetic_iter, true, schedule, dims_expr)
else
# Single iterator without filter
iterator = iter_or_filter
return _threadsfor_single_iterator(body, iterator, true, schedule)
end
end

function _threadsfor_single_iterator(body, iterator, condition, schedule, dims=nothing)
lidx = iterator.args[1]
range = iterator.args[2]
esc_range = esc(range)
esc_lidx = esc(lidx)
esc_body = esc(body)
esc_condition = condition === true ? true : esc(condition)

func = if schedule === :greedy
greedy_comprehension_func(esc_range, esc_lidx, esc_body, esc_condition)
else
default_comprehension_func(esc_range, esc_lidx, esc_body, esc_condition)
end

# Collect (index, result) pairs from channel and sort to preserve order
result_expr = quote
close(result_channel)
pairs = collect(result_channel)
if isempty(pairs)
[]
else
sort!(pairs, by=first)
[p[2] for p in pairs]
end
end

# If dims is provided, reshape the result to match original comprehension dimensions
if dims !== nothing
result_expr = quote
let flat_result = $result_expr
reshape(flat_result, $(dims))
end
end
end

quote
local threadsfor_fun
local result_channel = $func
$(_threading_run_expr(schedule))
$result_expr
end
end

function greedy_func(itr, lidx, lbody)
quote
let c = Channel{eltype($itr)}(threadpoolsize(), spawn=true) do ch
Expand All @@ -237,41 +354,70 @@ function greedy_func(itr, lidx, lbody)
end
end

function default_func(itr, lidx, lbody)
function greedy_comprehension_func(itr, esc_lidx, esc_body, esc_condition)
quote
let range = $itr
function threadsfor_fun(tid = 1; onethread = false)
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
if onethread
tid = 1
len, rem = lenr, 0
else
len, rem = divrem(lenr, threadpoolsize())
let c = Channel(threadpoolsize(), spawn=true) do ch
for (idx, item) in enumerate($itr)
put!(ch, (idx, item))
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
return
end
result_channel = Channel(Inf)

function threadsfor_fun(tid)
for (idx, item) in c
local $esc_lidx = item
if $esc_condition
put!(result_channel, (idx, $esc_body))
end
len, rem = 1, 0
end
# compute this thread's iterations
f = firstindex(r) + ((tid-1) * len)
l = f + len - 1
# distribute remaining iterations evenly
if rem > 0
if tid <= rem
f = f + (tid-1)
l = l + tid
else
f = f + rem
l = l + rem
end
end
result_channel # Return channel so results can be collected after threading_run
end
end
end

# Helper function to generate work distribution code
function _work_distribution_code()
quote
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
if onethread
tid = 1
len, rem = lenr, 0
else
len, rem = divrem(lenr, threadpoolsize())
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
return
end
# run this thread's iterations
for i = f:l
len, rem = 1, 0
end
# compute this thread's iterations
loop_first = firstindex(r) + ((tid-1) * len)
loop_last = loop_first + len - 1
# distribute remaining iterations evenly
if rem > 0
if tid <= rem
loop_first = loop_first + (tid-1)
loop_last = loop_last + tid
else
loop_first = loop_first + rem
loop_last = loop_last + rem
end
end
end
end

function default_func(itr, lidx, lbody)
work_dist = _work_distribution_code()
quote
let range = $itr
function threadsfor_fun(tid = 1; onethread = false)
$work_dist
for i = loop_first:loop_last
local $(esc(lidx)) = @inbounds r[i]
$(esc(lbody))
end
Expand All @@ -280,13 +426,40 @@ function default_func(itr, lidx, lbody)
end
end

function default_comprehension_func(itr, esc_lidx, esc_body, esc_condition)
work_dist = _work_distribution_code()
quote
let iter = $itr
# Collect non-indexable iterators (like ProductIterator)
range = (iter isa AbstractArray || hasmethod(getindex, Tuple{typeof(iter), Int})) ? iter : collect(iter)
result_channel = Channel(Inf)

function threadsfor_fun(tid = 1; onethread = false)
$work_dist
for i = loop_first:loop_last
local $esc_lidx = @inbounds r[i]
if $esc_condition
put!(result_channel, (i, $esc_body))
end
end
end
result_channel # Return channel so results can be collected after threading_run
end
end
end

"""
Threads.@threads [schedule] for ... end
Threads.@threads [schedule] [expr for ... end]

A macro to execute a `for` loop in parallel. The iteration space is distributed to
A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
coarse-grained tasks. This policy can be specified by the `schedule` argument. The
execution of the loop waits for the evaluation of all iterations.

For `for` loops, the macro executes the loop body in parallel but does not return a value.
For array comprehensions, the macro executes the comprehension in parallel and returns
the collected results as an array.

Tasks spawned by `@threads` are scheduled on the `:default` threadpool. This means that
`@threads` will not use threads from the `:interactive` threadpool, even if called from
the main thread or from a task in the interactive pool. The `:default` threadpool is
Expand Down Expand Up @@ -377,6 +550,8 @@ thread other than 1.

## Examples

### For loops

To illustrate of the different scheduling strategies, consider the following function
`busywait` containing a non-yielding timed loop that runs for a given number of seconds.

Expand Down Expand Up @@ -406,6 +581,57 @@ julia> @time begin

The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
to run two of the 1-second iterations to complete the for loop.

### Array comprehensions

The `@threads` macro also supports array comprehensions, which return the collected results.
Array comprehensions preserve element order for all scheduling options. Multi-dimensional
comprehensions preserve the dimensions of the original comprehension (e.g., `[f(i,j) for i in 1:n, j in 1:m]`
returns an `n×m` matrix).

```julia-repl
julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
5-element Vector{Int64}:
1
4
9
16
25

julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
2-element Vector{Int64}:
4
16

julia> Threads.@threads [i + j for i in 1:3, j in 1:3] # Multiple loops
3×3 Matrix{Int64}:
2 3 4
3 4 5
4 5 6
```

When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
option can be used.
```julia-repl
julia> c = Channel(5, spawn=true) do ch
foreach(i -> put!(ch, i), 1:5)
end;

julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
2-element Vector{Int64}:
4
16

julia> # Non-indexable iterators are also supported
Threads.@threads [i for i in Iterators.flatten([1:3, 4:6])]
6-element Vector{Int64}:
1
2
3
4
5
6
```
"""
macro threads(args...)
na = length(args)
Expand All @@ -426,13 +652,18 @@ macro threads(args...)
else
throw(ArgumentError("wrong number of arguments in @threads"))
end
if !(isa(ex, Expr) && ex.head === :for)
throw(ArgumentError("@threads requires a `for` loop expression"))
end
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
if isa(ex, Expr) && ex.head === :comprehension
# Handle array comprehensions
return _threadsfor_comprehension(ex.args[1], sched)
elseif isa(ex, Expr) && ex.head === :for
# Handle for loops
if !(ex.args[1] isa Expr && ex.args[1].head === :(=))
throw(ArgumentError("nested outer loops are not currently supported by @threads"))
end
return _threadsfor(ex.args[1], ex.args[2], sched)
else
throw(ArgumentError("@threads requires a `for` loop or comprehension expression"))
end
return _threadsfor(ex.args[1], ex.args[2], sched)
end

function _spawn_set_thrpool(t::Task, tp::Symbol)
Expand Down
Loading