From d7d95feab8c17dfe7334bbcf42a8fddc693d7238 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Wed, 2 Nov 2016 00:36:13 +0000
Subject: [PATCH] actually get GRU working

---
 src/backend/tensorflow/graph.jl |  4 ++++
 src/layers/recurrent.jl         | 36 ++++++++++++++++++------------------
 2 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index a55e311f..6ca27a0b 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -2,14 +2,18 @@ import Base: @get!
 import DataFlow: Constant, postwalk, value, inputs, constant
 import TensorFlow: RawTensor
 
+# TODO: implement Julia's type promotion rules
+
 cvalue(x) = x
 cvalue(c::Constant) = c.value
 cvalue(v::Vertex) = cvalue(value(v))
 
 graph(x::Tensor) = x
+graph(x::Number) = TensorFlow.constant(Float32(x))
 
 graph(::typeof(*), args...) = *(args...)
 graph(::typeof(.*), args...) = .*(args...)
+graph(::typeof(.-), args...) = -(args...)
 graph(::typeof(+), args...) = +(args...)
 graph(::typeof(softmax), x) = nn.softmax(x)
 graph(::typeof(relu), x) = nn.relu(x)
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index f8fb8943..ae192674 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -1,4 +1,4 @@
-export Recurrent, LSTM
+export Recurrent, GatedRecurrent, LSTM
 
 @net type Recurrent
   Wxy; Wyy; by
@@ -11,6 +11,23 @@ end
 Recurrent(in, out; init = initn) =
   Recurrent(init((in, out)), init((out, out)), init(out), init(out))
 
+@net type GatedRecurrent
+  Wxr; Wyr; br
+  Wxu; Wyu; bu
+  Wxh; Wyh; bh
+  y
+  function (x)
+    reset = σ( x * Wxr + y * Wyr + br )
+    update = σ( x * Wxu + y * Wyu + bu )
+    y′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
+    y = (1 .- update) .* y′ + update .* y
+  end
+end
+
+GatedRecurrent(in, out; init = initn) =
+  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
+       zeros(Float32, out))
+
 @net type LSTM
   Wxf; Wyf; bf
   Wxi; Wyi; bi
@@ -32,20 +49,3 @@ end
 LSTM(in, out; init = initn) =
   LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
        zeros(Float32, out), zeros(Float32, out))
-
-@net type GatedRecurrent
-  Wxr; Wyr; br
-  Wxu; Wyu; bu
-  Wxh; Wyh; bh
-  state
-  function (x)
-    reset = σ( x * Wxr + y * Wyr + br )
-    update = σ( x * Wxu + y * Wyu + bu )
-    state′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
-    state = (1 .- update) .* state′ + update .* y
-  end
-end
-
-GatedRecurrent(in, out; init = initn) =
-  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
-       zeros(Float32, out))
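
Note (appended for context, not part of the patch): the old GatedRecurrent stored its hidden state in a field named `state` while its gate equations referenced `y`, which the type never defined, so the recurrence was never wired to the state it carried. The patch renames the field to `y` so the equations close over it, exports the layer, and teaches the TensorFlow backend about `.-` and bare numeric literals so that `1 .- update` can compile to the graph. Below is a minimal plain-Julia sketch of the same GRU step, outside the @net machinery; `gru_step` and the row-vector bias shapes are illustrative assumptions, not Flux API.

# Plain-Julia sketch of the GRU step that GatedRecurrent implements.
# Everything here (gru_step, parameter shapes) is illustrative, not Flux API.
σ(z) = 1 ./ (1 .+ exp.(-z))                      # logistic sigmoid, elementwise

function gru_step(x, y, Wxr, Wyr, br, Wxu, Wyu, bu, Wxh, Wyh, bh)
  reset  = σ(x * Wxr + y * Wyr + br)             # reset gate
  update = σ(x * Wxu + y * Wyu + bu)             # update gate
  y′ = tanh.(x * Wxh + (reset .* y) * Wyh + bh)  # candidate state
  (1 .- update) .* y′ + update .* y              # mix old state and candidate
end

# One step with a 1×nin input row and a 1×nout state; the parameter list is
# built with the same 1:3 comprehension pattern the constructor uses.
nin, nout = 3, 2
ps = vcat([[randn(nin, nout), randn(nout, nout), randn(1, nout)] for _ = 1:3]...)
x, y = randn(1, nin), zeros(1, nout)
y = gru_step(x, y, ps...)

Iterating `y = gru_step(xt, y, ps...)` over a sequence of inputs reproduces the recurrence the layer defines.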