add comment on broken layers
This commit is contained in:
parent
c4409fa6d1
commit
0801064d50
|
@ -10,11 +10,13 @@
|
|||
@test gradient(x -> sum(cpu(x)), gpu(rand(3,3))) isa Tuple
|
||||
end
|
||||
|
||||
# TODO: These layers get into scalar indexing
|
||||
# `AlphaDropout` throws a compilation error on GPUs,
|
||||
# whereas, the rest are scalar indexing issues.
|
||||
const BROKEN_LAYERS = [DepthwiseConv,
|
||||
AlphaDropout,]
|
||||
|
||||
const THROWING_LAYERS = [InstanceNorm,
|
||||
GroupNorm]
|
||||
AlphaDropout,
|
||||
InstanceNorm,
|
||||
GroupNorm]
|
||||
|
||||
function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
||||
isnothing(xs) && error("Missing input to test the layers against.")
|
||||
|
@ -26,9 +28,6 @@ function gradtest(name::String, layers::Vector, xs = nothing, args...)
|
|||
if any(x -> isa(l, x), BROKEN_LAYERS)
|
||||
ps = Flux.params(l)
|
||||
@test_broken gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
elseif any(x -> isa(l, x), THROWING_LAYERS)
|
||||
ps = Flux.params(l)
|
||||
@test_throws ErrorException gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
else
|
||||
ps = Flux.params(l)
|
||||
@test gradient(() -> sum(l(xs)), ps) isa Flux.Zygote.Grads
|
||||
|
|
Loading…
Reference in New Issue