test backward pass
This commit is contained in:
parent
1f1b1ee5b2
commit
777ecd2bc4
@ -9,14 +9,27 @@ d = Affine(20, 10)
|
|||||||
dm = mxnet(d, (20, 1))
|
dm = mxnet(d, (20, 1))
|
||||||
@test d(xs) ≈ dm(xs)
|
@test d(xs) ≈ dm(xs)
|
||||||
|
|
||||||
|
@testset "Backward Pass" begin
|
||||||
|
d′ = deepcopy(d)
|
||||||
|
@test dm(xs) ≈ d(xs)
|
||||||
|
@test dm(xs) ≈ d′(xs)
|
||||||
|
|
||||||
|
Δ = back!(dm, randn(10), xs)
|
||||||
|
@test length(Δ) == 20
|
||||||
|
update!(dm, 0.1)
|
||||||
|
|
||||||
|
@test dm(xs) ≈ d(xs)
|
||||||
|
@test dm(xs) ≉ d′(xs)
|
||||||
|
end
|
||||||
|
|
||||||
@testset "FeedForward interface" begin
|
@testset "FeedForward interface" begin
|
||||||
# TODO: test run
|
|
||||||
f = mx.FeedForward(Chain(d, softmax))
|
f = mx.FeedForward(Chain(d, softmax))
|
||||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||||
|
|
||||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||||
flatten, Affine(1587, 10), softmax)
|
flatten, Affine(1587, 10), softmax)
|
||||||
f = mx.FeedForward(m)
|
f = mx.FeedForward(m)
|
||||||
|
# TODO: test run
|
||||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user