unrolling test
This commit is contained in:
parent
7e983c74cb
commit
5111db4037
|
@ -122,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
|
|||
|
||||
function unroll1(model)
|
||||
graph, state = unrollgraph(model, 1)
|
||||
Stateful(Capacitor(graph), state)
|
||||
Stateful(Capacitor(unseq(graph)), state)
|
||||
end
|
||||
|
||||
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
||||
|
|
|
@ -140,6 +140,9 @@ function (m::Stateful)(x)
|
|||
end
|
||||
end
|
||||
|
||||
stateless(m) = m
|
||||
stateless(m::Stateful) = m.model
|
||||
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
using Flux: stateless
|
||||
|
||||
function apply(model, xs, state)
|
||||
ys = similar(xs, 0)
|
||||
for x in xs
|
||||
state, y = model(state, x)
|
||||
push!(ys, y)
|
||||
end
|
||||
state, ys
|
||||
end
|
||||
|
||||
@testset "RNN unrolling" begin
|
||||
r = Recurrent(10, 5)
|
||||
xs = [rand(10) for _ = 1:3]
|
||||
_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
|
||||
ru = unroll(r, 3)
|
||||
@test ru(Seq(xs)) == ys
|
||||
end
|
|
@ -15,5 +15,6 @@ end
|
|||
|
||||
include("batching.jl")
|
||||
include("basic.jl")
|
||||
include("recurrent.jl")
|
||||
@tfonly include("backend/tensorflow.jl")
|
||||
@mxonly include("backend/mxnet.jl")
|
||||
|
|
Loading…
Reference in New Issue