add comment on broken layers
This commit is contained in:
parent
c4409fa6d1
commit
0801064d50
|
@ -10,11 +10,13 @@
|
||||||
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# TODO: These layers get into scalar indexing
|
||||||
|
# `AlphaDropout` throws a compilation error on GPUs,
|
||||||
|
# whereas, the rest are scalar indexing issues.
|
||||||
const BROKEN_LAYERS = [DepthwiseConv,
|
const BROKEN_LAYERS = [DepthwiseConv,
|
||||||
AlphaDropout,]
|
AlphaDropout,
|
||||||
|
InstanceNorm,
|
||||||
const THROWING_LAYERS = [InstanceNorm,
|
GroupNorm]
|
||||||
GroupNorm]
|
|
||||||
|
|
||||||
function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
||||||
isnothing(xs) && error("Missing input to test the layers against.")
|
isnothing(xs) && error("Missing input to test the layers against.")
|
||||||
|
@ -26,9 +28,6 @@ function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
||||||
if any(x -> isa(l, x), BROKEN_LAYERS)
|
if any(x -> isa(l, x), BROKEN_LAYERS)
|
||||||
ps = Flux.params(l)
|
ps = Flux.params(l)
|
||||||
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||||
elseif any(x -> isa(l, x), THROWING_LAYERS)
|
|
||||||
ps = Flux.params(l)
|
|
||||||
@test_throws ErrorException gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
|
||||||
else
|
else
|
||||||
ps = Flux.params(l)
|
ps = Flux.params(l)
|
||||||
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||||
|
|
Loading…
Reference in New Issue