# Flux.jl/test/recurrent.jl — tests for RNN unrolling.
# (Captured from a web blame view: 18 lines, 420 B, Julia; first hunk committed 2017-03-29 17:30:28 +00:00.)
"""
    apply(model, xs, state)

Drive `model` over the elements of `xs` one at a time, threading `state` through.

Each call `model(state, x)` must return a `(newstate, y)` pair. The outputs `y` are
collected in order into a vector created with `similar(xs, 0)`, so it has the same
element type as `xs`. Returns `(finalstate, ys)`; for empty `xs` this is
`(state, <empty vector>)`.
"""
function apply(model, xs, state)
    outputs = similar(xs, 0)
    for i in eachindex(xs)
        state, out = model(state, xs[i])
        push!(outputs, out)
    end
    return state, outputs
end
@testset "RNN unrolling" begin
    r = Recurrent(10, 5)
    xs = [rand(1, 10) for _ = 1:3]

    # Step the single-step model over the inputs by hand, starting from the
    # layer's stored state `r.y.x`.
    _, ys = apply(unroll1(r).model, xs, (r.y.x,))

    # The first output must match the explicit recurrence written out below.
    # `tanh.` broadcasts elementwise. Plain `tanh` on a matrix is the matrix
    # function in Julia >= 0.7 and needs a square argument, but this is a 1-row
    # matrix. `ys[1]` is a Float64 matrix, so compare approximately rather than
    # with exact `==`.
    @test ys[1] ≈ tanh.(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)

    # The 3-step unrolled model should reproduce the hand-stepped outputs.
    # `@test` is required: a bare `==` here would be evaluated and its result
    # silently discarded.
    ru = unroll(r, 3)
    @test ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
end