test backward pass

This commit is contained in:
Mike J Innes 2017-02-23 22:51:37 +00:00
parent 1f1b1ee5b2
commit 777ecd2bc4
1 changed files with 14 additions and 1 deletions

View File

@ -9,14 +9,27 @@ d = Affine(20, 10)
dm = mxnet(d, (20, 1))
@test d(xs) dm(xs)
@testset "Backward Pass" begin
d = deepcopy(d)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
Δ = back!(dm, randn(10), xs)
@test length(Δ) == 20
update!(dm, 0.1)
@test dm(xs) d(xs)
@test dm(xs) d(xs)
end
@testset "FeedForward interface" begin
# TODO: test run
f = mx.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = mx.FeedForward(m)
# TODO: test run
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end