diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index c75cfb4e..128e5c7d 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -58,6 +58,13 @@ end @test y[3,:] isa CuArray end +@testset "restructure gpu" begin + dudt = Dense(1,1) |> gpu + p,re = Flux.destructure(dudt) + foo(x) = sum(re(p)(x)) + @test gradient(foo, cu(rand(1)))[1] isa CuArray +end + if CuArrays.has_cudnn() @info "Testing Flux/CUDNN" include("cudnn.jl")