test restructure on the GPU

Requires https://github.com/FluxML/Zygote.jl/pull/474
This commit is contained in:
Chris Rackauckas 2020-01-20 13:53:28 -05:00
parent d1edd9b16d
commit 9803826a36
1 changed files with 7 additions and 0 deletions

View File

@ -58,6 +58,13 @@ end
@test y[3,:] isa CuArray
end
@testset "restructure gpu" begin
dudt = Dense(1,1) |> gpu
p,re = Flux.destructure(dudt)
foo(x) = sum(re(p)(x))
@test gradient(foo, cu(rand(1)))[1] isa CuArray
end
if CuArrays.has_cudnn()
@info "Testing Flux/CUDNN"
include("cudnn.jl")