actually get GRU working
parent 85415d4244
commit d7d95feab8
@@ -2,14 +2,18 @@ import Base: @get!
 import DataFlow: Constant, postwalk, value, inputs, constant
 import TensorFlow: RawTensor
 
+# TODO: implement Julia's type promotion rules
+
 cvalue(x) = x
 cvalue(c::Constant) = c.value
 cvalue(v::Vertex) = cvalue(value(v))
 
 graph(x::Tensor) = x
+graph(x::Number) = TensorFlow.constant(Float32(x))
 
 graph(::typeof(*), args...) = *(args...)
 graph(::typeof(.*), args...) = .*(args...)
+graph(::typeof(.-), args...) = -(args...)
 graph(::typeof(+), args...) = +(args...)
 graph(::typeof(softmax), x) = nn.softmax(x)
 graph(::typeof(relu), x) = nn.relu(x)
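These `graph` methods form the lowering table for the TensorFlow backend: each Flux-level call dispatches to the matching TensorFlow op. The three additions lift bare Julia numbers into the graph as `Float32` constants and lower broadcasted subtraction, presumably what the GRU's `1 .- update` expression (added below) requires. A minimal sketch of the target API, assuming the TensorFlow.jl interface of this era; the snippet is illustrative and not part of the commit:

    using TensorFlow

    sess = Session(Graph())
    x = TensorFlow.constant(Float32[1.0 2.0 3.0])  # what graph(x::Number) emits for a literal
    y = nn.softmax(x)                              # what graph(::typeof(softmax), x) lowers to
    run(sess, y)                                   # evaluate the lowered graph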
@@ -1,4 +1,4 @@
-export Recurrent, LSTM
+export Recurrent, GatedRecurrent, LSTM
 
 @net type Recurrent
   Wxy; Wyy; by
@@ -11,6 +11,23 @@ end
 Recurrent(in, out; init = initn) =
   Recurrent(init((in, out)), init((out, out)), init(out), init(out))
 
+@net type GatedRecurrent
+  Wxr; Wyr; br
+  Wxu; Wyu; bu
+  Wxh; Wyh; bh
+  y
+  function (x)
+    reset = σ( x * Wxr + y * Wyr + br )
+    update = σ( x * Wxu + y * Wyu + bu )
+    y′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
+    y = (1 .- update) .* y′ + update .* y
+  end
+end
+
+GatedRecurrent(in, out; init = initn) =
+  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
+                 zeros(Float32, out))
+
 @net type LSTM
   Wxf; Wyf; bf
   Wxi; Wyi; bi
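The new `GatedRecurrent` is a standard GRU cell: `reset` and `update` are sigmoid gates, `y′` is the candidate hidden state, and the field `y` doubles as the recurrent state (initialised by the constructor's `zeros(Float32, out)`), updated as an interpolation between candidate and previous state. A plain-Julia reference sketch of one step, illustrative only, since the real layer is compiled by `@net` onto a backend:

    σ(z) = 1 ./ (1 .+ exp.(-z))   # logistic sigmoid, elementwise

    # One GRU step: x is a 1×in row, y a 1×out row, weights shaped as in the
    # constructor above, biases passed as 1×out rows.
    function gru_step(x, y, Wxr, Wyr, br, Wxu, Wyu, bu, Wxh, Wyh, bh)
        reset  = σ(x * Wxr .+ y * Wyr .+ br)             # reset gate
        update = σ(x * Wxu .+ y * Wyu .+ bu)             # update gate
        y′ = tanh.(x * Wxh .+ (reset .* y) * Wyh .+ bh)  # candidate state
        (1 .- update) .* y′ .+ update .* y               # new hidden state
    end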
@@ -32,20 +49,3 @@ end
 LSTM(in, out; init = initn) =
   LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
        zeros(Float32, out), zeros(Float32, out))
-
-@net type GatedRecurrent
-  Wxr; Wyr; br
-  Wxu; Wyu; bu
-  Wxh; Wyh; bh
-  state
-  function (x)
-    reset = σ( x * Wxr + y * Wyr + br )
-    update = σ( x * Wxu + y * Wyu + bu )
-    state′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
-    state = (1 .- update) .* state′ + update .* y
-  end
-end
-
-GatedRecurrent(in, out; init = initn) =
-  GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
-                 zeros(Float32, out))
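The deleted block is the earlier, broken GRU definition that this commit replaces: its state field was named `state`, yet the gate equations read `y`, and the final update blends `state′` with `y`, a variable the type never defines. The rewritten definition earlier in the file names the field `y`, so the recurrence actually refers to the stored state; it also now sits before `LSTM` rather than after it.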