@@ -220,6 +220,73 @@ function _threadsfor(iter, lbody, schedule)
220220 end
221221end
222222
223+ function _threadsfor_comprehension (gen:: Expr , schedule)
224+ @assert gen. head === :generator
225+
226+ body = gen. args[1 ]
227+ iter_or_filter = gen. args[2 ]
228+
229+ # Handle filtered vs non-filtered comprehensions
230+ if isa (iter_or_filter, Expr) && iter_or_filter. head === :filter
231+ condition = iter_or_filter. args[1 ]
232+ iterator = iter_or_filter. args[2 ]
233+ return _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
234+ else
235+ iterator = iter_or_filter
236+ # Use filtered comprehension with `true` condition for non-filtered case
237+ return _threadsfor_filtered_comprehension (body, iterator, true , schedule)
238+ end
239+ end
240+
241+ function _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
242+ lidx = iterator. args[1 ] # index variable
243+ range = iterator. args[2 ] # range/iterable
244+ esc_range = esc (range)
245+ esc_body = esc (body)
246+ esc_condition = esc (condition)
247+
248+ if schedule === :greedy
249+ quote
250+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
251+ for item in $ esc_range
252+ put! (ch, item)
253+ end
254+ end
255+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
256+ function threadsfor_fun (tid)
257+ local_results = Any[]
258+ for item in ch
259+ local $ (esc (lidx)) = item
260+ if $ esc_condition
261+ push! (local_results, $ esc_body)
262+ end
263+ end
264+ thread_result_storage[tid] = local_results
265+ end
266+ threading_run (threadsfor_fun, false )
267+ # Collect results after threading_run
268+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
269+ vcat (assigned_results... )
270+ end
271+ else
272+ func = default_filtered_comprehension_func (esc_range, lidx, esc_body, esc_condition)
273+ quote
274+ local threadsfor_fun
275+ local result
276+ $ func
277+ if $ (schedule === :dynamic || schedule === :default )
278+ threading_run (threadsfor_fun, false )
279+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
280+ error (" `@threads :static` cannot be used concurrently or nested" )
281+ else # :static
282+ threading_run (threadsfor_fun, true )
283+ end
284+ # Process result after threading_run
285+ vcat (result... )
286+ end
287+ end
288+ end
289+
223290function greedy_func (itr, lidx, lbody)
224291 quote
225292 let c = Channel {eltype($itr)} (0 ,spawn= true ) do ch
@@ -237,39 +304,47 @@ function greedy_func(itr, lidx, lbody)
237304 end
238305end
239306
307+ # Helper function to generate work distribution code
308+ function _work_distribution_code ()
309+ quote
310+ r = range # Load into local variable
311+ lenr = length (r)
312+ # divide loop iterations among threads
313+ if onethread
314+ tid = 1
315+ len, rem = lenr, 0
316+ else
317+ len, rem = divrem (lenr, threadpoolsize ())
318+ end
319+ # not enough iterations for all the threads?
320+ if len == 0
321+ if tid > rem
322+ return
323+ end
324+ len, rem = 1 , 0
325+ end
326+ # compute this thread's iterations
327+ f = firstindex (r) + ((tid- 1 ) * len)
328+ l = f + len - 1
329+ # distribute remaining iterations evenly
330+ if rem > 0
331+ if tid <= rem
332+ f = f + (tid- 1 )
333+ l = l + tid
334+ else
335+ f = f + rem
336+ l = l + rem
337+ end
338+ end
339+ end
340+ end
341+
240342function default_func (itr, lidx, lbody)
343+ work_dist = _work_distribution_code ()
241344 quote
242345 let range = $ itr
243346 function threadsfor_fun (tid = 1 ; onethread = false )
244- r = range # Load into local variable
245- lenr = length (r)
246- # divide loop iterations among threads
247- if onethread
248- tid = 1
249- len, rem = lenr, 0
250- else
251- len, rem = divrem (lenr, threadpoolsize ())
252- end
253- # not enough iterations for all the threads?
254- if len == 0
255- if tid > rem
256- return
257- end
258- len, rem = 1 , 0
259- end
260- # compute this thread's iterations
261- f = firstindex (r) + ((tid- 1 ) * len)
262- l = f + len - 1
263- # distribute remaining iterations evenly
264- if rem > 0
265- if tid <= rem
266- f = f + (tid- 1 )
267- l = l + tid
268- else
269- f = f + rem
270- l = l + rem
271- end
272- end
347+ $ work_dist
273348 # run this thread's iterations
274349 for i = f: l
275350 local $ (esc (lidx)) = @inbounds r[i]
@@ -280,13 +355,46 @@ function default_func(itr, lidx, lbody)
280355 end
281356end
282357
358+ function default_filtered_comprehension_func (itr, lidx, body, condition)
359+ work_dist = _work_distribution_code ()
360+ quote
361+ let range = $ itr
362+ local thread_results = Vector {Vector{Any}} (undef, threadpoolsize ())
363+ # Initialize all result vectors to empty
364+ for i in 1 : threadpoolsize ()
365+ thread_results[i] = Any[]
366+ end
367+
368+ function threadsfor_fun (tid = 1 ; onethread = false )
369+ $ work_dist
370+ # run this thread's iterations with filtering
371+ local_results = Any[]
372+ for i = f: l
373+ local $ (esc (lidx)) = @inbounds r[i]
374+ if $ condition
375+ push! (local_results, $ body)
376+ end
377+ end
378+ thread_results[tid] = local_results
379+ end
380+
381+ result = thread_results # This will be populated by threading_run
382+ end
383+ end
384+ end
385+
283386"""
284387 Threads.@threads [schedule] for ... end
388+ Threads.@threads [schedule] [expr for ... end]
285389
286- A macro to execute a `for` loop in parallel. The iteration space is distributed to
390+ A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287391coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288392execution of the loop waits for the evaluation of all iterations.
289393
394+ For `for` loops, the macro executes the loop body in parallel but does not return a value.
395+ For array comprehensions, the macro executes the comprehension in parallel and returns
396+ the collected results as an array.
397+
290398See also: [`@spawn`](@ref Threads.@spawn) and
291399`pmap` in [`Distributed`](@ref man-distributed).
292400
@@ -371,6 +479,8 @@ thread other than 1.
371479
372480## Examples
373481
482+ ### For loops
483+
374484To illustrate of the different scheduling strategies, consider the following function
375485`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376486
@@ -400,6 +510,38 @@ julia> @time begin
400510
401511The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402512to run two of the 1-second iterations to complete the for loop.
513+
514+ ### Array comprehensions
515+
516+ The `@threads` macro also supports array comprehensions, which return the collected results:
517+
518+ ```julia-repl
519+ julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
520+ 5-element Vector{Int64}:
521+ 1
522+ 4
523+ 9
524+ 16
525+ 25
526+
527+ julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
528+ 2-element Vector{Int64}:
529+ 4
530+ 16
531+ ```
532+
533+ When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
534+ option can be used, but note that the order of the results is not guaranteed.
535+ ```julia-repl
536+ julia> c = Channel(5, spawn=true) do ch
537+ foreach(i -> put!(ch, i), 1:5)
538+ end;
539+
540+ julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
541+ 2-element Vector{Any}:
542+ 16
543+ 4
544+ ```
403545"""
404546macro threads (args... )
405547 na = length (args)
@@ -420,13 +562,18 @@ macro threads(args...)
420562 else
421563 throw (ArgumentError (" wrong number of arguments in @threads" ))
422564 end
423- if ! (isa (ex, Expr) && ex. head === :for )
424- throw (ArgumentError (" @threads requires a `for` loop expression" ))
425- end
426- if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
427- throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
565+ if isa (ex, Expr) && ex. head === :comprehension
566+ # Handle array comprehensions
567+ return _threadsfor_comprehension (ex. args[1 ], sched)
568+ elseif isa (ex, Expr) && ex. head === :for
569+ # Handle for loops
570+ if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
571+ throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
572+ end
573+ return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
574+ else
575+ throw (ArgumentError (" @threads requires a `for` loop or comprehension expression" ))
428576 end
429- return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
430577end
431578
432579function _spawn_set_thrpool (t:: Task , tp:: Symbol )
0 commit comments