Add tests
This commit is contained in:
parent
33a7f545b7
commit
52a50b2727
@ -1,6 +1,6 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck
|
using Flux.Tracker: TrackedReal, gradcheck
|
||||||
using NNlib: conv
|
using NNlib: conv, depthwiseconv
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||||
@ -169,6 +169,8 @@ end
|
|||||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||||
|
|
||||||
|
@test gradtest(depthwiseconv, rand(10,10,3,2), randn(2, 2, 1,3))
|
||||||
|
|
||||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user