unrolling test
This commit is contained in:
parent
7e983c74cb
commit
5111db4037
@ -122,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
|
|||||||
|
|
||||||
function unroll1(model)
|
function unroll1(model)
|
||||||
graph, state = unrollgraph(model, 1)
|
graph, state = unrollgraph(model, 1)
|
||||||
Stateful(Capacitor(graph), state)
|
Stateful(Capacitor(unseq(graph)), state)
|
||||||
end
|
end
|
||||||
|
|
||||||
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
||||||
|
@ -140,6 +140,9 @@ function (m::Stateful)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
stateless(m) = m
|
||||||
|
stateless(m::Stateful) = m.model
|
||||||
|
|
||||||
struct SeqModel
|
struct SeqModel
|
||||||
model
|
model
|
||||||
steps::Int
|
steps::Int
|
||||||
|
18
test/recurrent.jl
Normal file
18
test/recurrent.jl
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
using Flux: stateless
|
||||||
|
|
||||||
|
function apply(model, xs, state)
|
||||||
|
ys = similar(xs, 0)
|
||||||
|
for x in xs
|
||||||
|
state, y = model(state, x)
|
||||||
|
push!(ys, y)
|
||||||
|
end
|
||||||
|
state, ys
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "RNN unrolling" begin
|
||||||
|
r = Recurrent(10, 5)
|
||||||
|
xs = [rand(10) for _ = 1:3]
|
||||||
|
_, ys = apply(stateless(unroll1(r)), xs, (squeeze(r.y.x, 1),))
|
||||||
|
ru = unroll(r, 3)
|
||||||
|
@test ru(Seq(xs)) == ys
|
||||||
|
end
|
@ -15,5 +15,6 @@ end
|
|||||||
|
|
||||||
include("batching.jl")
|
include("batching.jl")
|
||||||
include("basic.jl")
|
include("basic.jl")
|
||||||
|
include("recurrent.jl")
|
||||||
@tfonly include("backend/tensorflow.jl")
|
@tfonly include("backend/tensorflow.jl")
|
||||||
@mxonly include("backend/mxnet.jl")
|
@mxonly include("backend/mxnet.jl")
|
||||||
|
Loading…
Reference in New Issue
Block a user