998: test restructure on the GPU r=CarloLucibello a=ChrisRackauckas

Requires https://github.com/FluxML/Zygote.jl/pull/474 to pass

Co-authored-by: Chris Rackauckas <accounts@chrisrackauckas.com>
This commit is contained in:
bors[bot] 2020-02-29 09:08:11 +00:00 committed by GitHub
commit 2dd23574c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 0 deletions

View File

@ -58,6 +58,13 @@ end
@test y[3,:] isa CuArray
end
@testset "restructure gpu" begin
dudt = Dense(1,1) |> gpu
p,re = Flux.destructure(dudt)
foo(x) = sum(re(p)(x))
@test gradient(foo, cu(rand(1)))[1] isa CuArray
end
if CuArrays.has_cudnn()
@info "Testing Flux/CUDNN"
include("cudnn.jl")