unrolling test

This commit is contained in:
Mike J Innes 2017-03-29 18:30:28 +01:00
parent 7e983c74cb
commit 5111db4037
4 changed files with 23 additions and 1 deletions

View File

@ -122,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
function unroll1(model)
graph, state = unrollgraph(model, 1)
Stateful(Capacitor(graph), state)
Stateful(Capacitor(unseq(graph)), state)
end
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))

View File

@ -140,6 +140,9 @@ function (m::Stateful)(x)
end
end
stateless(m) = m
stateless(m::Stateful) = m.model
struct SeqModel
model
steps::Int

18
test/recurrent.jl Normal file
View File

@ -0,0 +1,18 @@
using Flux: stateless
function apply(model, xs, state)
ys = similar(xs, 0)
for x in xs
state, y = model(state, x)
push!(ys, y)
end
state, ys
end
@testset "RNN unrolling" begin
r = Recurrent(10, 5)
xs = [rand(10) for _ = 1:3]
_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
ru = unroll(r, 3)
@test ru(Seq(xs)) == ys
end

View File

@ -15,5 +15,6 @@ end
include("batching.jl")
include("basic.jl")
include("recurrent.jl")
@tfonly include("backend/tensorflow.jl")
@mxonly include("backend/mxnet.jl")