1- # https://github.com/EnzymeAD/Enzyme.jl/issues/1516
2- # On the CPU `autodiff_deferred` can deadlock.
3- # Hence a specialized CPU version
4- function cpu_fwd (ctx, f, args... )
5- EnzymeCore. autodiff (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
6- return nothing
7- end
8-
9- function gpu_fwd (ctx, f, args... )
1+ function fwd (ctx, f, args... )
102 EnzymeCore. autodiff_deferred (Forward, Const (f), Const{Nothing}, Const (ctx), args... )
113 return nothing
124end
135
146function EnzymeRules. forward (
15- func:: Const{<:Kernel{CPU}} ,
16- :: Type{Const{Nothing}} ,
17- args... ;
18- ndrange = nothing ,
19- workgroupsize = nothing ,
20- )
21- kernel = func. val
22- f = kernel. f
23- fwd_kernel = similar (kernel, cpu_fwd)
24-
25- return fwd_kernel (f, args... ; ndrange, workgroupsize)
26- end
27-
28- function EnzymeRules. forward (
29- func:: Const{<:Kernel{<:GPU}} ,
7+ func:: Const{<:Kernel} ,
308 :: Type{Const{Nothing}} ,
319 args... ;
3210 ndrange = nothing ,
3311 workgroupsize = nothing ,
3412 )
3513 kernel = func. val
3614 f = kernel. f
37- fwd_kernel = similar (kernel, gpu_fwd )
15+ fwd_kernel = similar (kernel, fwd )
3816
3917 return fwd_kernel (f, args... ; ndrange, workgroupsize)
4018end
4119
42- _enzyme_mkcontext (kernel:: Kernel{CPU} , ndrange, iterspace, dynamic) =
43- mkcontext (kernel, first (blocks (iterspace)), ndrange, iterspace, dynamic)
44- _enzyme_mkcontext (kernel:: Kernel{<:GPU} , ndrange, iterspace, dynamic) =
20+ _enzyme_mkcontext (kernel:: Kernel , ndrange, iterspace, dynamic) =
4521 mkcontext (kernel, ndrange, iterspace)
4622
47- _augmented_return (:: Kernel{CPU} , subtape, arg_refs, tape_type) =
48- AugmentedReturn {Nothing, Nothing, Tuple{Array, typeof(arg_refs), typeof(tape_type)}} (
49- nothing ,
50- nothing ,
51- (subtape, arg_refs, tape_type),
52- )
53- _augmented_return (:: Kernel{<:GPU} , subtape, arg_refs, tape_type) =
23+ _augmented_return (:: Kernel , subtape, arg_refs, tape_type) =
5424 AugmentedReturn {Nothing, Nothing, Any} (nothing , nothing , (subtape, arg_refs, tape_type))
5525
5626function _create_tape_kernel (
57- kernel:: Kernel{CPU} ,
58- ModifiedBetween,
59- FT,
60- ctxTy,
61- ndrange,
62- iterspace,
63- args2... ,
64- )
65- TapeType = EnzymeCore. tape_type (
66- ReverseSplitModified (ReverseSplitWithPrimal, ModifiedBetween),
67- FT,
68- Const{Nothing},
69- Const{ctxTy},
70- map (Core. Typeof, args2)... ,
71- )
72- subtape = Array {TapeType} (undef, size (blocks (iterspace)))
73- aug_kernel = similar (kernel, cpu_aug_fwd)
74- return TapeType, subtape, aug_kernel
75- end
76-
77- function _create_tape_kernel (
78- kernel:: Kernel{<:GPU} ,
27+ kernel:: Kernel ,
7928 ModifiedBetween,
8029 FT,
8130 ctxTy,
@@ -104,60 +53,11 @@ function _create_tape_kernel(
10453 # Allocate per thread
10554 subtape = allocate (backend (kernel), TapeType, prod (ndrange))
10655
107- aug_kernel = similar (kernel, gpu_aug_fwd )
56+ aug_kernel = similar (kernel, aug_fwd )
10857 return TapeType, subtape, aug_kernel
10958end
11059
111- _create_rev_kernel (kernel:: Kernel{CPU} ) = similar (kernel, cpu_rev)
112- _create_rev_kernel (kernel:: Kernel{<:GPU} ) = similar (kernel, gpu_rev)
113-
114- function cpu_aug_fwd (
115- ctx,
116- f:: FT ,
117- :: Val{ModifiedBetween} ,
118- subtape,
119- :: Val{TapeType} ,
120- args... ,
121- ) where {ModifiedBetween, FT, TapeType}
122- # A2 = Const{Nothing} -- since f->Nothing
123- forward, _ = EnzymeCore. autodiff_thunk (
124- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)),
125- Const{Core. Typeof (f)},
126- Const{Nothing},
127- Const{Core. Typeof (ctx)},
128- map (Core. Typeof, args)... ,
129- )
130-
131- # On the CPU: F is a per block function
132- # On the CPU: subtape::Vector{Vector}
133- I = __index_Group_Cartesian (ctx, CartesianIndex (1 , 1 )) #= fake=#
134- subtape[I] = forward (Const (f), Const (ctx), args... )[1 ]
135- return nothing
136- end
137-
138- function cpu_rev (
139- ctx,
140- f:: FT ,
141- :: Val{ModifiedBetween} ,
142- subtape,
143- :: Val{TapeType} ,
144- args... ,
145- ) where {ModifiedBetween, FT, TapeType}
146- _, reverse = EnzymeCore. autodiff_thunk (
147- ReverseSplitModified (ReverseSplitWithPrimal, Val (ModifiedBetween)),
148- Const{Core. Typeof (f)},
149- Const{Nothing},
150- Const{Core. Typeof (ctx)},
151- map (Core. Typeof, args)... ,
152- )
153- I = __index_Group_Cartesian (ctx, CartesianIndex (1 , 1 )) #= fake=#
154- tp = subtape[I]
155- reverse (Const (f), Const (ctx), args... , tp)
156- return nothing
157- end
158-
159- # GPU support
160- function gpu_aug_fwd (
60+ function aug_fwd (
16161 ctx,
16262 f:: FT ,
16363 :: Val{ModifiedBetween} ,
@@ -184,7 +84,7 @@ function gpu_aug_fwd(
18484 return nothing
18585end
18686
187- function gpu_rev (
87+ function rev (
18888 ctx,
18989 f:: FT ,
19090 :: Val{ModifiedBetween} ,
@@ -232,11 +132,7 @@ function EnzymeRules.augmented_primal(
232132 arg_refs = ntuple (Val (N)) do i
233133 Base. @_inline_meta
234134 if args[i] isa Active
235- if func. val isa Kernel{<: GPU }
236- error (" Active kernel arguments not supported on GPU" )
237- else
238- Ref (EnzymeCore. make_zero (args[i]. val))
239- end
135+ error (" Active kernel arguments not supported" )
240136 else
241137 nothing
242138 end
@@ -292,7 +188,7 @@ function EnzymeRules.reverse(
292188
293189 ModifiedBetween = Val ((overwritten (config)[1 ], false , overwritten (config)[2 : end ]. .. ))
294190
295- rev_kernel = _create_rev_kernel (kernel)
191+ rev_kernel = similar (kernel, rev )
296192 rev_kernel (
297193 f,
298194 ModifiedBetween,
0 commit comments