recurrent models in tf
This commit is contained in:
parent
b511160ec4
commit
796d7d7e99
@ -44,3 +44,10 @@ for f in :[back!, update!].args
|
||||
error($(string(f)) * " is not yet supported on TensorFlow models")
|
||||
end
|
||||
end
|
||||
|
||||
# Recurrent Models
|
||||
|
||||
using Flux: Stateful, SeqModel
|
||||
|
||||
tf(m::Stateful) = Stateful(tf(m.model), m.istate, m.ostate)
|
||||
tf(m::SeqModel) = SeqModel(tf(m.model), m.steps)
|
||||
|
@ -14,6 +14,13 @@ dt = tf(d)
|
||||
@test tf(@net x -> x[1].*x[2])(([1,2,3],[4,5,6])) == [4,10,18]
|
||||
end
|
||||
|
||||
@testset "Recurrence" begin
|
||||
seq = batchone(Seq(rand(10) for i = 1:3))
|
||||
r = unroll(Recurrent(10, 5), 3)
|
||||
rm = tf(r)
|
||||
@test r(seq) ≈ rm(seq)
|
||||
end
|
||||
|
||||
@testset "Tensor interface" begin
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
|
Loading…
Reference in New Issue
Block a user