test restructure on the GPU
Requires https://github.com/FluxML/Zygote.jl/pull/474
This commit is contained in:
parent
d1edd9b16d
commit
9803826a36
|
@ -58,6 +58,13 @@ end
|
|||
@test y[3,:] isa CuArray
|
||||
end
|
||||
|
||||
@testset "restructure gpu" begin
|
||||
dudt = Dense(1,1) |> gpu
|
||||
p,re = Flux.destructure(dudt)
|
||||
foo(x) = sum(re(p)(x))
|
||||
@test gradient(foo, cu(rand(1)))[1] isa CuArray
|
||||
end
|
||||
|
||||
if CuArrays.has_cudnn()
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
|
|
Loading…
Reference in New Issue