clearing failures
This commit is contained in:
parent
26631e1361
commit
c4409fa6d1
|
@ -5,8 +5,16 @@
|
|||
# Check that getting the gradients does not throw
|
||||
|
||||
# generic movement tests
|
||||
@test_broken gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
|
||||
@test_throws ErrorException gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||
@testset "Basic GPU Movement" begin
|
||||
@test gradient(x -> sum(gpu(x)), rand(3,3)) isa Tuple
|
||||
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||
end
|
||||
|
||||
const BROKEN_LAYERS = [DepthwiseConv,
|
||||
AlphaDropout,]
|
||||
|
||||
const THROWING_LAYERS = [InstanceNorm,
|
||||
GroupNorm]
|
||||
|
||||
function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
||||
isnothing(xs) && error("Missing input to test the layers against.")
|
||||
|
@ -15,9 +23,12 @@ function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
|||
@testset "$layer GPU grad test" begin
|
||||
l = gpu(layer(args...))
|
||||
xs = gpu(xs)
|
||||
if l isa DepthwiseConv || l isa AlphaDropout
|
||||
if any(x -> isa(l, x), BROKEN_LAYERS)
|
||||
ps = Flux.params(l)
|
||||
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
elseif any(x -> isa(l, x), THROWING_LAYERS)
|
||||
ps = Flux.params(l)
|
||||
@test_throws ErrorException gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
else
|
||||
ps = Flux.params(l)
|
||||
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
|
@ -46,10 +57,10 @@ dropout_layers = [Dropout, AlphaDropout]
|
|||
gradtest("Dropout", dropout_layers, r, 0.5f0)
|
||||
|
||||
norm_layers = [LayerNorm, BatchNorm]
|
||||
gradtest("Normalising" norm_layers, rand(Float32, 28,28,3,1), 1)
|
||||
gradtest("Normalising", norm_layers, rand(Float32, 28,28,3,1), 1)
|
||||
|
||||
instancenorm = [InstanceNorm]
|
||||
gradtest(instancenorm, r, "InstanceNorm", 1)
|
||||
gradtest("InstanceNorm", instancenorm, r, 1)
|
||||
|
||||
groupnorm = [GroupNorm]
|
||||
gradtest("GroupNorm", groupnorm, rand(Float32, 28,28,3,1), 3, 1)
|
||||
|
@ -67,11 +78,7 @@ function stateless_gradtest(f, args...)
|
|||
end
|
||||
|
||||
function stateless_gradtest_broadcasted(f, args...)
|
||||
if f == Flux.binarycrossentropy
|
||||
@test_broken gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
else
|
||||
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
|
||||
@testset "Stateless GPU grad tests" begin
|
||||
|
|
Loading…
Reference in New Issue