recurrent models in tf

This commit is contained in:
Mike J Innes 2017-05-01 18:05:17 +01:00
parent b511160ec4
commit 796d7d7e99
2 changed files with 14 additions and 0 deletions

View File

@ -44,3 +44,10 @@ for f in :[back!, update!].args
error($(string(f)) * " is not yet supported on TensorFlow models")
end
end
# Recurrent Models
using Flux: Stateful, SeqModel
tf(m::Stateful) = Stateful(tf(m.model), m.istate, m.ostate)
tf(m::SeqModel) = SeqModel(tf(m.model), m.steps)

View File

@ -14,6 +14,13 @@ dt = tf(d)
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
end
@testset "Recurrence" begin
seq = batchone(Seq(rand(10) for i = 1:3))
r = unroll(Recurrent(10, 5), 3)
rm = tf(r)
@test r(seq) rm(seq)
end
@testset "Tensor interface" begin
sess = TensorFlow.Session()
X = placeholder(Float32)