@@ -220,6 +220,117 @@ function _threadsfor(iter, lbody, schedule)
220220 end
221221end
222222
223+ function _threadsfor_comprehension (gen:: Expr , schedule)
224+ @assert gen. head === :generator
225+
226+ body = gen. args[1 ]
227+ iter_or_filter = gen. args[2 ]
228+
229+ # Handle filtered vs non-filtered comprehensions
230+ if isa (iter_or_filter, Expr) && iter_or_filter. head === :filter
231+ condition = iter_or_filter. args[1 ]
232+ iterator = iter_or_filter. args[2 ]
233+ return _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
234+ else
235+ iterator = iter_or_filter
236+ return _threadsfor_simple_comprehension (body, iterator, schedule)
237+ end
238+ end
239+
240+ function _threadsfor_simple_comprehension (body, iterator, schedule)
241+ lidx = iterator. args[1 ] # index variable
242+ range = iterator. args[2 ] # range/iterable
243+ esc_range = esc (range)
244+ esc_body = esc (body)
245+
246+ if schedule === :greedy
247+ quote
248+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
249+ for item in $ esc_range
250+ put! (ch, item)
251+ end
252+ end
253+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
254+ function threadsfor_fun (tid)
255+ local_results = Any[]
256+ for item in ch
257+ local $ (esc (lidx)) = item
258+ push! (local_results, $ esc_body)
259+ end
260+ thread_result_storage[tid] = local_results
261+ end
262+ threading_run (threadsfor_fun, false )
263+ # Collect results after threading_run
264+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
265+ vcat (assigned_results... )
266+ end
267+ else
268+ func = default_comprehension_func (esc_range, lidx, esc_body)
269+ quote
270+ local threadsfor_fun
271+ local result
272+ $ func
273+ if $ (schedule === :dynamic || schedule === :default )
274+ threading_run (threadsfor_fun, false )
275+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
276+ error (" `@threads :static` cannot be used concurrently or nested" )
277+ else # :static
278+ threading_run (threadsfor_fun, true )
279+ end
280+ result
281+ end
282+ end
283+ end
284+
285+ function _threadsfor_filtered_comprehension (body, iterator, condition, schedule)
286+ lidx = iterator. args[1 ] # index variable
287+ range = iterator. args[2 ] # range/iterable
288+ esc_range = esc (range)
289+ esc_body = esc (body)
290+ esc_condition = esc (condition)
291+
292+ if schedule === :greedy
293+ quote
294+ local ch = Channel {eltype($esc_range)} (0 ,spawn= true ) do ch
295+ for item in $ esc_range
296+ put! (ch, item)
297+ end
298+ end
299+ local thread_result_storage = Vector {Vector{Any}} (undef, threadpoolsize ())
300+ function threadsfor_fun (tid)
301+ local_results = Any[]
302+ for item in ch
303+ local $ (esc (lidx)) = item
304+ if $ esc_condition
305+ push! (local_results, $ esc_body)
306+ end
307+ end
308+ thread_result_storage[tid] = local_results
309+ end
310+ threading_run (threadsfor_fun, false )
311+ # Collect results after threading_run
312+ assigned_results = [thread_result_storage[i] for i in 1 : threadpoolsize () if isassigned (thread_result_storage, i)]
313+ vcat (assigned_results... )
314+ end
315+ else
316+ func = default_filtered_comprehension_func (esc_range, lidx, esc_body, esc_condition)
317+ quote
318+ local threadsfor_fun
319+ local result
320+ $ func
321+ if $ (schedule === :dynamic || schedule === :default )
322+ threading_run (threadsfor_fun, false )
323+ elseif ccall (:jl_in_threaded_region , Cint, ()) != 0 # :static
324+ error (" `@threads :static` cannot be used concurrently or nested" )
325+ else # :static
326+ threading_run (threadsfor_fun, true )
327+ end
328+ # Process result after threading_run
329+ vcat (result... )
330+ end
331+ end
332+ end
333+
223334function greedy_func (itr, lidx, lbody)
224335 quote
225336 let c = Channel {eltype($itr)} (0 ,spawn= true ) do ch
@@ -237,39 +348,47 @@ function greedy_func(itr, lidx, lbody)
237348 end
238349end
239350
351+ # Helper function to generate work distribution code
352+ function _work_distribution_code ()
353+ quote
354+ r = range # Load into local variable
355+ lenr = length (r)
356+ # divide loop iterations among threads
357+ if onethread
358+ tid = 1
359+ len, rem = lenr, 0
360+ else
361+ len, rem = divrem (lenr, threadpoolsize ())
362+ end
363+ # not enough iterations for all the threads?
364+ if len == 0
365+ if tid > rem
366+ return
367+ end
368+ len, rem = 1 , 0
369+ end
370+ # compute this thread's iterations
371+ f = firstindex (r) + ((tid- 1 ) * len)
372+ l = f + len - 1
373+ # distribute remaining iterations evenly
374+ if rem > 0
375+ if tid <= rem
376+ f = f + (tid- 1 )
377+ l = l + tid
378+ else
379+ f = f + rem
380+ l = l + rem
381+ end
382+ end
383+ end
384+ end
385+
240386function default_func (itr, lidx, lbody)
387+ work_dist = _work_distribution_code ()
241388 quote
242389 let range = $ itr
243390 function threadsfor_fun (tid = 1 ; onethread = false )
244- r = range # Load into local variable
245- lenr = length (r)
246- # divide loop iterations among threads
247- if onethread
248- tid = 1
249- len, rem = lenr, 0
250- else
251- len, rem = divrem (lenr, threadpoolsize ())
252- end
253- # not enough iterations for all the threads?
254- if len == 0
255- if tid > rem
256- return
257- end
258- len, rem = 1 , 0
259- end
260- # compute this thread's iterations
261- f = firstindex (r) + ((tid- 1 ) * len)
262- l = f + len - 1
263- # distribute remaining iterations evenly
264- if rem > 0
265- if tid <= rem
266- f = f + (tid- 1 )
267- l = l + tid
268- else
269- f = f + rem
270- l = l + rem
271- end
272- end
391+ $ work_dist
273392 # run this thread's iterations
274393 for i = f: l
275394 local $ (esc (lidx)) = @inbounds r[i]
@@ -280,13 +399,68 @@ function default_func(itr, lidx, lbody)
280399 end
281400end
282401
402+ function default_comprehension_func (itr, lidx, body)
403+ work_dist = _work_distribution_code ()
404+ quote
405+ result = let range = $ itr
406+ lenr = length (range)
407+ # Pre-allocate result array with the correct size
408+ local result_array = Vector {Any} (undef, lenr)
409+
410+ function threadsfor_fun (tid = 1 ; onethread = false )
411+ $ work_dist
412+ # run this thread's iterations and store directly in result_array
413+ for i = f: l
414+ local $ (esc (lidx)) = @inbounds r[i]
415+ result_array[i] = $ body
416+ end
417+ end
418+
419+ result_array
420+ end
421+ end
422+ end
423+
424+ function default_filtered_comprehension_func (itr, lidx, body, condition)
425+ work_dist = _work_distribution_code ()
426+ quote
427+ let range = $ itr
428+ local thread_results = Vector {Vector{Any}} (undef, threadpoolsize ())
429+ # Initialize all result vectors to empty
430+ for i in 1 : threadpoolsize ()
431+ thread_results[i] = Any[]
432+ end
433+
434+ function threadsfor_fun (tid = 1 ; onethread = false )
435+ $ work_dist
436+ # run this thread's iterations with filtering
437+ local_results = Any[]
438+ for i = f: l
439+ local $ (esc (lidx)) = @inbounds r[i]
440+ if $ condition
441+ push! (local_results, $ body)
442+ end
443+ end
444+ thread_results[tid] = local_results
445+ end
446+
447+ result = thread_results # This will be populated by threading_run
448+ end
449+ end
450+ end
451+
283452"""
284453 Threads.@threads [schedule] for ... end
454+ Threads.@threads [schedule] [expr for ... end]
285455
286- A macro to execute a `for` loop in parallel. The iteration space is distributed to
456+ A macro to execute a `for` loop or array comprehension in parallel. The iteration space is distributed to
287457coarse-grained tasks. This policy can be specified by the `schedule` argument. The
288458execution of the loop waits for the evaluation of all iterations.
289459
460+ For `for` loops, the macro executes the loop body in parallel but does not return a value.
461+ For array comprehensions, the macro executes the comprehension in parallel and returns
462+ the collected results as an array.
463+
290464See also: [`@spawn`](@ref Threads.@spawn) and
291465`pmap` in [`Distributed`](@ref man-distributed).
292466
@@ -371,6 +545,8 @@ thread other than 1.
371545
372546## Examples
373547
548+ ### For loops
549+
374550To illustrate of the different scheduling strategies, consider the following function
375551`busywait` containing a non-yielding timed loop that runs for a given number of seconds.
376552
@@ -400,6 +576,38 @@ julia> @time begin
400576
401577The `:dynamic` example takes 2 seconds since one of the non-occupied threads is able
402578to run two of the 1-second iterations to complete the for loop.
579+
580+ ### Array comprehensions
581+
582+ The `@threads` macro also supports array comprehensions, which return the collected results:
583+
584+ ```julia-repl
585+ julia> Threads.@threads [i^2 for i in 1:5] # Simple comprehension
586+ 5-element Vector{Int64}:
587+ 1
588+ 4
589+ 9
590+ 16
591+ 25
592+
593+ julia> Threads.@threads [i^2 for i in 1:5 if iseven(i)] # Filtered comprehension
594+ 2-element Vector{Int64}:
595+ 4
596+ 16
597+ ```
598+
599+ When the iterator doesn't have a known length, such as a channel, the `:greedy` scheduling
600+ option can be used, but note that the order of the results is not guaranteed.
601+ ```julia-repl
602+ julia> c = Channel(5, spawn=true) do ch
603+ foreach(i -> put!(ch, i), 1:5)
604+ end;
605+
606+ julia> Threads.@threads :greedy [i^2 for i in c if iseven(i)]
607+ 2-element Vector{Any}:
608+ 16
609+ 4
610+ ```
403611"""
404612macro threads (args... )
405613 na = length (args)
@@ -420,13 +628,18 @@ macro threads(args...)
420628 else
421629 throw (ArgumentError (" wrong number of arguments in @threads" ))
422630 end
423- if ! (isa (ex, Expr) && ex. head === :for )
424- throw (ArgumentError (" @threads requires a `for` loop expression" ))
425- end
426- if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
427- throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
631+ if isa (ex, Expr) && ex. head === :comprehension
632+ # Handle array comprehensions
633+ return _threadsfor_comprehension (ex. args[1 ], sched)
634+ elseif isa (ex, Expr) && ex. head === :for
635+ # Handle for loops
636+ if ! (ex. args[1 ] isa Expr && ex. args[1 ]. head === :(= ))
637+ throw (ArgumentError (" nested outer loops are not currently supported by @threads" ))
638+ end
639+ return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
640+ else
641+ throw (ArgumentError (" @threads requires a `for` loop or comprehension expression" ))
428642 end
429- return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
430643end
431644
432645function _spawn_set_thrpool (t:: Task , tp:: Symbol )
0 commit comments