actually get GRU working
parent 85415d4244
commit d7d95feab8
@@ -2,14 +2,18 @@ import Base: @get!
import DataFlow: Constant, postwalk, value, inputs, constant
import TensorFlow: RawTensor

# TODO: implement Julia's type promotion rules

cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))

graph(x::Tensor) = x
graph(x::Number) = TensorFlow.constant(Float32(x))

graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...)
graph(::typeof(.-), args...) = -(args...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
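
The graph(::typeof(op), args...) methods above lean entirely on Julia's multiple dispatch: the Julia function object itself (+, .*, softmax, and so on) selects which TensorFlow node to build. A minimal standalone sketch of the same dispatch pattern, using a made-up mygraph that returns tagged tuples instead of real TensorFlow ops:

# Toy illustration only (not Flux's backend): dispatch on the function object
# maps each Julia op to a backend-specific tag.
mygraph(::typeof(+), args...) = (:add, args...)
mygraph(::typeof(*), args...) = (:matmul, args...)

mygraph(+, 1, 2)  # (:add, 1, 2)
mygraph(*, 2, 3)  # (:matmul, 2, 3)
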
@@ -1,4 +1,4 @@
export Recurrent, LSTM
export Recurrent, GatedRecurrent, LSTM

@net type Recurrent
  Wxy; Wyy; by
@@ -11,6 +11,23 @@ end
Recurrent(in, out; init = initn) =
  Recurrent(init((in, out)), init((out, out)), init(out), init(out))

@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  y
  function (x)
    reset = σ( x * Wxr + y * Wyr + br )
    update = σ( x * Wxu + y * Wyu + bu )
    y′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
    y = (1 .- update) .* y′ + update .* y
  end
end

GatedRecurrent(in, out; init = initn) =
  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
                 zeros(Float32, out))
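
For intuition, the same gated update can be written as a plain Julia function, outside the @net/TensorFlow machinery. This is only a sketch: gru_step, elsigmoid, and the explicit parameter list are illustrative names rather than library code, and x and y are assumed to be 1×in and 1×out row vectors as in the layer above.

# Plain-Julia sketch of the GatedRecurrent body (illustrative, not the library's code path).
elsigmoid(z) = 1 ./ (1 .+ exp.(-z))   # elementwise logistic, standing in for σ

function gru_step(x, y, Wxr, Wyr, br, Wxu, Wyu, bu, Wxh, Wyh, bh)
    reset  = elsigmoid(x * Wxr .+ y * Wyr .+ br')      # gate: how much of y feeds the candidate
    update = elsigmoid(x * Wxu .+ y * Wyu .+ bu')      # gate: how much of y is carried over
    y′ = tanh.(x * Wxh .+ (reset .* y) * Wyh .+ bh')   # candidate output
    (1 .- update) .* y′ .+ update .* y                 # blend candidate with previous output
end

# Toy usage with in = 2, out = 3 and small random parameters:
ps = [0.01f0 .* randn(Float32, d...) for _ = 1:3 for d in ((2, 3), (3, 3), (3,))]
gru_step(randn(Float32, 1, 2), zeros(Float32, 1, 3), ps...)   # 1×3 next output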

@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
@@ -32,20 +49,3 @@ end
LSTM(in, out; init = initn) =
  LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
       zeros(Float32, out), zeros(Float32, out))

@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  state
  function (x)
    reset = σ( x * Wxr + y * Wyr + br )
    update = σ( x * Wxu + y * Wyu + bu )
    state′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
    state = (1 .- update) .* state′ + update .* y
  end
end

GatedRecurrent(in, out; init = initn) =
  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
                 zeros(Float32, out))
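
Both the GatedRecurrent and LSTM constructors above use the same vcat/splat idiom: the comprehension builds one [Wx, Wy, b] triple per gate, vcat(...)... flattens those triples into a single vector, and the trailing ... splats that vector into the type's positional fields, followed by the zero initial state(s). A tiny standalone illustration (init here is a stand-in for Flux's initn, and f is just a throwaway function that counts its arguments):

# Stand-in initializer (assumption): zeros of the requested shape.
init(dims) = zeros(Float32, dims...)

# One [Wx, Wy, b] triple per gate, flattened into a 9-element parameter vector.
params = vcat([[init((2, 3)), init((3, 3)), init(3)] for _ = 1:3]...)
length(params)                    # 9: Wxr, Wyr, br, Wxu, Wyu, bu, Wxh, Wyh, bh

f(args...) = length(args)
f(params..., zeros(Float32, 3))   # 10 positional arguments, like GatedRecurrent(..., y)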