formatting fixes
This commit is contained in:
parent
717ad9328d
commit
76dc8ea9d4
|
@ -9,23 +9,23 @@
|
|||
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||
|
||||
function gradtest(layers, args...; name = "Conv", xs = rand(Float32, 28, 28, 1, 1))
|
||||
@testset "$name GPU grad tests" begin
|
||||
for layer in layers
|
||||
@testset "$layer GPU grad test" begin
|
||||
l = gpu(layer(args...))
|
||||
xs = gpu(xs)
|
||||
if l isa DepthwiseConv
|
||||
@test_broken gradient(Flux.params(l)) do
|
||||
sum(l(xs))
|
||||
end isa Flux.Zygote.Grads
|
||||
else
|
||||
@test gradient(Flux.params(l)) do
|
||||
sum(l(xs))
|
||||
end isa Flux.Zygote.Grads
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
@testset "$name GPU grad tests" begin
|
||||
for layer in layers
|
||||
@testset "$layer GPU grad test" begin
|
||||
l = gpu(layer(args...))
|
||||
xs = gpu(xs)
|
||||
if l isa DepthwiseConv
|
||||
@test_broken gradient(Flux.params(l)) do
|
||||
sum(l(xs))
|
||||
end isa Flux.Zygote.Grads
|
||||
else
|
||||
@test gradient(Flux.params(l)) do
|
||||
sum(l(xs))
|
||||
end isa Flux.Zygote.Grads
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
# Repeats from Conv, CrossCor
|
||||
|
@ -49,38 +49,38 @@ groupnorm = [GroupNorm]
|
|||
gradtest(groupnorm, 3, 1, name = "GroupNorm", xs = rand(Float32, 28,28,3,1))
|
||||
|
||||
const stateless_layers = [Flux.mse,
|
||||
Flux.crossentropy,
|
||||
Flux.logitcrossentropy,]
|
||||
Flux.normalise]
|
||||
Flux.crossentropy,
|
||||
Flux.logitcrossentropy,]
|
||||
Flux.normalise]
|
||||
|
||||
const stateless_layers_broadcasted = [Flux.binarycrossentropy,
|
||||
Flux.logitbinarycrossentropy]
|
||||
Flux.logitbinarycrossentropy]
|
||||
|
||||
function stateless_gradtest(f, args...)
|
||||
@test gradient((args...) -> sum(f(args...)), args...)[1] isa CuArray
|
||||
@test gradient((args...) -> sum(f(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
|
||||
function stateless_gradtest_broadcasted(f, args...)
|
||||
if f == Flux.binarycrossentropy
|
||||
@test_broken gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
else
|
||||
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
if f == Flux.binarycrossentropy
|
||||
@test_broken gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
else
|
||||
@test gradient((args...) -> sum(f.(args...)), args...)[1] isa CuArray
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Stateless GPU grad tests" begin
|
||||
x = gpu(rand(3,3))
|
||||
y = gpu(rand(3,3))
|
||||
x = gpu(rand(3,3))
|
||||
y = gpu(rand(3,3))
|
||||
|
||||
for layer in stateless_layers
|
||||
if layer == Flux.normalise
|
||||
stateless_gradtest(layer, x)
|
||||
else
|
||||
stateless_gradtest(layer, x, y)
|
||||
end
|
||||
end
|
||||
for layer in stateless_layers
|
||||
if layer == Flux.normalise
|
||||
stateless_gradtest(layer, x)
|
||||
else
|
||||
stateless_gradtest(layer, x, y)
|
||||
end
|
||||
end
|
||||
|
||||
for layer in stateless_layers_broadcasted
|
||||
stateless_gradtest_broadcasted(layer, x, y)
|
||||
end
|
||||
for layer in stateless_layers_broadcasted
|
||||
stateless_gradtest_broadcasted(layer, x, y)
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue