@@ -6,10 +6,6 @@ using Functors
66using FiniteDifferences
77using CUDA
88
9- Enzyme. API. typeWarning! (false ) # suppresses a warning with Bilinear https://github.com/EnzymeAD/Enzyme.jl/issues/1341
10- Enzyme. API. runtimeActivity! (true ) # for Enzyme debugging
11- # Enzyme.Compiler.bitcode_replacement!(false)
12-
139_make_zero (x:: Union{Number,AbstractArray} ) = zero (x)
1410_make_zero (x) = x
1511make_zero (model) = fmap (_make_zero, model)
121117 (SkipConnection (Dense (2 => 2 ), vcat), randn (Float32, 2 , 3 ), " SkipConnection" ),
122118 (Flux. Bilinear ((2 , 2 ) => 3 ), randn (Float32, 2 , 1 ), " Bilinear" ),
123119 (GRU (3 => 5 ), randn (Float32, 3 , 10 ), " GRU" ),
120+ (ConvTranspose ((3 , 3 ), 3 => 2 , stride= 2 ), rand (Float32, 5 , 5 , 3 , 1 ), " ConvTranspose" ),
124121 ]
125122
126123 for (model, x, name) in models_xs
155152 end
156153 end
157154end
158-
159- @testset " Broken Models" begin
160- function loss (model, x)
161- Flux. reset! (model)
162- sum (model (x))
163- end
164-
165- device = Flux. get_device ()
166-
167- models_xs = [
168- # Pending https://github.com/FluxML/NNlib.jl/issues/565
169- (ConvTranspose ((3 , 3 ), 3 => 2 , stride= 2 ), rand (Float32, 5 , 5 , 3 , 1 ), " ConvTranspose" ),
170- ]
171-
172- for (model, x, name) in models_xs
173- @testset " check grad $name " begin
174- println (" testing $name " )
175- broken = false
176- try
177- test_enzyme_grad (loss, model, x)
178- catch e
179- println (e)
180- broken = true
181- end
182- @test broken
183- end
184- end
185- end
186-
0 commit comments