The issue was originally reported in
FluxML/Flux.jl#2513
and is caused by these specializations:
https://github.com/LuxDL/MLDataDevices.jl/blob/17419d27888e3a48b52f318249a0f037524f0f1e/src/public.jl#L341
The simplest solution would be to remove the specializations and let `fmap` handle everything.
Any optimized implementations could be upstreamed to Functors.jl, but here we really want to keep track of object identity, which `fmap` does through its cache.
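
To illustrate why object identity matters, here is a minimal sketch. It assumes the specializations behave roughly like a plain `map` over the container's elements (an assumption for illustration, not a copy of the linked code), and uses `copy` as a stand-in for a device transfer:

```julia
using Functors

w = [1.0, 2.0]
tied = (a = w, b = w)            # both fields refer to the same array

# Plain `map` visits each element independently, so the sharing is lost.
via_map = map(copy, tied)
via_map.a === via_map.b          # false

# `fmap` caches already-visited mutable leaves, so the sharing is preserved.
via_fmap = fmap(copy, tied)
via_fmap.a === via_fmap.b        # true
```

With `fmap`, the shared array is converted once and the result is reused, so the two fields stay tied after the move.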