test backward pass
This commit is contained in:
parent
1f1b1ee5b2
commit
777ecd2bc4
|
@ -9,14 +9,27 @@ d = Affine(20, 10)
|
|||
dm = mxnet(d, (20, 1))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
d′ = deepcopy(d)
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≈ d′(xs)
|
||||
|
||||
Δ = back!(dm, randn(10), xs)
|
||||
@test length(Δ) == 20
|
||||
update!(dm, 0.1)
|
||||
|
||||
@test dm(xs) ≈ d(xs)
|
||||
@test dm(xs) ≉ d′(xs)
|
||||
end
|
||||
|
||||
@testset "FeedForward interface" begin
|
||||
# TODO: test run
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||
|
||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||
flatten, Affine(1587, 10), softmax)
|
||||
f = mx.FeedForward(m)
|
||||
# TODO: test run
|
||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue