@@ -86,11 +86,8 @@ function transform_gpu!(def, constargs, force_inbounds, unsafe_indices)
8686end
8787
8888struct WorkgroupLoop
89- indices:: Vector{Any}
9089 stmts:: Vector{Any}
9190 allocations:: Vector{Any}
92- private_allocations:: Vector{Any}
93- private:: Set{Symbol}
9491 terminated_in_sync:: Bool
9592end
9693
@@ -111,26 +108,18 @@ function find_sync(stmt)
111108end
112109
113110# TODO proper handling of LineInfo
114- function split (
115- stmts,
116- indices = Any[], private = Set {Symbol} (),
117- )
111+ function split (stmts)
118112 # 1. Split the code into blocks separated by `@synchronize`
119- # 2. Aggregate `@index` expressions
120- # 3. Hoist allocations
121- # 4. Hoist uniforms
122113
123114 current = Any[]
124115 allocations = Any[]
125- private_allocations = Any[]
126116 new_stmts = Any[]
127117 for stmt in stmts
128118 has_sync = find_sync (stmt)
129119 if has_sync
130- loop = WorkgroupLoop (deepcopy (indices), current, allocations, private_allocations, deepcopy (private) , is_sync (stmt))
120+ loop = WorkgroupLoop (current, allocations, is_sync (stmt))
131121 push! (new_stmts, emit (loop))
132122 allocations = Any[]
133- private_allocations = Any[]
134123 current = Any[]
135124 is_sync (stmt) && continue
136125
@@ -142,7 +131,7 @@ function split(
142131 function recurse (expr:: Expr )
143132 expr = unblock (expr)
144133 if is_scope_construct (expr) && any (find_sync, expr. args)
145- new_args = unblock (split (expr. args, deepcopy (indices), deepcopy (private) ))
134+ new_args = unblock (split (expr. args))
146135 return Expr (expr. head, new_args... )
147136 else
148137 return Expr (expr. head, map (recurse, expr. args)... )
@@ -156,14 +145,10 @@ function split(
156145 push! (allocations, stmt)
157146 continue
158147 elseif @capture (stmt, @private lhs_ = rhs_)
159- push! (private, lhs)
160- push! (private_allocations, :($ lhs = $ rhs))
148+ push! (allocations, :($ lhs = $ rhs))
161149 continue
162150 elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
163- if @capture (rhs, @index (args__))
164- push! (indices, stmt)
165- continue
166- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
151+ if @capture (rhs, @localmem (args__) | @uniform (args__))
167152 push! (allocations, stmt)
168153 continue
169154 elseif @capture (rhs, @private (T_, dims_))
@@ -175,7 +160,6 @@ function split(
175160 end
176161 alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
177162 push! (allocations, :($ lhs = $ alloc))
178- push! (private, lhs)
179163 continue
180164 end
181165 end
@@ -185,7 +169,7 @@ function split(
185169
186170 # everything since the last `@synchronize`
187171 if ! isempty (current)
188- loop = WorkgroupLoop (deepcopy (indices), current, allocations, private_allocations, deepcopy (private) , false )
172+ loop = WorkgroupLoop (current, allocations, false )
189173 push! (new_stmts, emit (loop))
190174 end
191175 return new_stmts
@@ -197,9 +181,7 @@ function emit(loop)
197181 body = Expr (:block , loop. stmts... )
198182 loopexpr = quote
199183 $ (loop. allocations... )
200- $ (loop. private_allocations... )
201184 if __active_lane__
202- $ (loop. indices... )
203185 $ (unblock (body))
204186 end
205187 end
0 commit comments