diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index b43ce003..62fe16a8 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -44,3 +44,10 @@ for f in :[back!, update!].args error($(string(f)) * " is not yet supported on TensorFlow models") end end + +# Recurrent Models + +using Flux: Stateful, SeqModel + +tf(m::Stateful) = Stateful(tf(m.model), m.istate, m.ostate) +tf(m::SeqModel) = SeqModel(tf(m.model), m.steps) diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl index f0f30cd0..35a97d7d 100644 --- a/test/backend/tensorflow.jl +++ b/test/backend/tensorflow.jl @@ -14,6 +14,13 @@ dt = tf(d) @test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18] end +@testset "Recurrence" begin + seq = batchone(Seq(rand(10) for i = 1:3)) + r = unroll(Recurrent(10, 5), 3) + rm = tf(r) + @test r(seq) ≈ rm(seq) +end + @testset "Tensor interface" begin sess = TensorFlow.Session() X = placeholder(Float32)