Merge #998
998: test restructure on the GPU r=CarloLucibello a=ChrisRackauckas Requires https://github.com/FluxML/Zygote.jl/pull/474 to pass Co-authored-by: Chris Rackauckas <accounts@chrisrackauckas.com>
This commit is contained in:
commit
2dd23574c0
|
@ -58,6 +58,13 @@ end
|
||||||
@test y[3,:] isa CuArray
|
@test y[3,:] isa CuArray
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "restructure gpu" begin
|
||||||
|
dudt = Dense(1,1) |> gpu
|
||||||
|
p,re = Flux.destructure(dudt)
|
||||||
|
foo(x) = sum(re(p)(x))
|
||||||
|
@test gradient(foo, cu(rand(1)))[1] isa CuArray
|
||||||
|
end
|
||||||
|
|
||||||
if CuArrays.has_cudnn()
|
if CuArrays.has_cudnn()
|
||||||
@info "Testing Flux/CUDNN"
|
@info "Testing Flux/CUDNN"
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
|
|
Loading…
Reference in New Issue