actually get GRU working

Mike J Innes 2016-11-02 00:36:13 +00:00
parent 85415d4244
commit d7d95feab8
2 changed files with 22 additions and 18 deletions

View File

@@ -2,14 +2,18 @@ import Base: @get!
import DataFlow: Constant, postwalk, value, inputs, constant
import TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x
graph(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...)
graph(::typeof(.-), args...) = -(args...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
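These `graph` methods map Flux's primitive operations onto the corresponding TensorFlow ops by dispatching on the type of the function value itself. Below is a minimal, self-contained sketch of that dispatch pattern; the names `lower` and `mysoftmax`, and the string results standing in for TensorFlow ops, are illustrative only and not part of Flux's API.

# Each method translates one operation by dispatching on the function's type;
# `lower` plays the role of `graph` above, and the returned strings are only
# stand-ins for the TensorFlow ops the real methods build.
mysoftmax(x) = exp.(x) ./ sum(exp.(x))

lower(::typeof(+), args...)   = "tf.add($(join(args, ", ")))"
lower(::typeof(*), args...)   = "tf.matmul($(join(args, ", ")))"
lower(::typeof(mysoftmax), x) = "tf.nn.softmax($x)"

lower(+, :a, :b)           # => "tf.add(a, b)"
lower(mysoftmax, :logits)  # => "tf.nn.softmax(logits)"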

View File

@@ -1,4 +1,4 @@
export Recurrent, LSTM
export Recurrent, GatedRecurrent, LSTM
@net type Recurrent
  Wxy; Wyy; by
@@ -11,6 +11,23 @@ end
Recurrent(in, out; init = initn) =
  Recurrent(init((in, out)), init((out, out)), init(out), init(out))
@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  y
  function (x)
    reset = σ( x * Wxr + y * Wyr + br )
    update = σ( x * Wxu + y * Wyu + bu )
    y′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
    y = (1 .- update) .* y′ + update .* y
  end
end
GatedRecurrent(in, out; init = initn) =
  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
                 zeros(Float32, out))
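For reference, the step above can also be written as a plain Julia function on ordinary arrays, which makes the gating easy to sanity-check; `gru_step`, the row-vector convention, and the local `σ` definition are illustrative assumptions, not part of Flux.

# One GRU step: x is a 1×nin row vector, y the previous 1×nout output.
σ(z) = 1 ./ (1 .+ exp.(-z))

function gru_step(x, y, Wxr, Wyr, br, Wxu, Wyu, bu, Wxh, Wyh, bh)
  reset  = σ(x * Wxr + y * Wyr .+ br')             # reset gate
  update = σ(x * Wxu + y * Wyu .+ bu')             # update gate
  y′ = tanh.(x * Wxh + (reset .* y) * Wyh .+ bh')  # candidate output
  (1 .- update) .* y′ .+ update .* y               # blend with previous output
end

nin, nout = 5, 3
x, y = randn(1, nin), zeros(1, nout)
W(i, j) = 0.1randn(i, j)
y = gru_step(x, y, W(nin, nout), W(nout, nout), zeros(nout),
                   W(nin, nout), W(nout, nout), zeros(nout),
                   W(nin, nout), W(nout, nout), zeros(nout))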
@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
@@ -32,20 +49,3 @@ end
LSTM(in, out; init = initn) =
  LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
       zeros(Float32, out), zeros(Float32, out))
@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  state
  function (x)
    reset = σ( x * Wxr + y * Wyr + br )
    update = σ( x * Wxu + y * Wyu + bu )
    state = tanh( x * Wxh + (reset .* y) * Wyh + bh )
    state = (1 .- update) .* state + update .* y
  end
end
GatedRecurrent(in, out; init = initn) =
  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
                 zeros(Float32, out))