2016-11-02 00:36:13 +00:00
|
|
|
|
export Recurrent, GatedRecurrent, LSTM
|
2016-10-30 11:41:52 +00:00
|
|
|
|
|
2016-10-28 23:13:32 +00:00
|
|
|
|
# Plain recurrent layer.
#
# Each step combines the current input `x` with the layer's own previous
# output and squashes the sum with `tanh`:
#
#     y = tanh(x * Wxy + y{-1} * Wyy + by)
#
# `y{-1}` is the `@net` delay notation (presumably "the value `y` had on the
# previous step"; the stored `y` field supplies the starting value — confirm
# against the `@net` implementation).
@net type Recurrent
  # Weights applied to the input.
  Wxy
  # Weights applied to the previous output.
  Wyy
  # Additive bias.
  by
  # Recurrent output / hidden value.
  y
  function (x)
    preactivation = x * Wxy + y{-1} * Wyy + by
    y = tanh(preactivation)
  end
end
|
|
|
|
|
|
2016-10-30 16:07:18 +00:00
|
|
|
|
# Convenience constructor: build a `Recurrent` layer mapping `in` features to
# `out` features.
#
# Every field, including the initial `y`, is produced by `init`, which is
# called with the requested shape: `(in, out)` and `(out, out)` for the two
# weight arrays, and `out` for both the bias and the initial output.
# `init` defaults to `initn`.
function Recurrent(in, out; init = initn)
  # Calls are made in field order (Wxy, Wyy, by, y) so that a stateful/random
  # `init` is consumed in the same sequence as the positional constructor
  # expects its arguments.
  input_weights     = init((in, out))
  recurrent_weights = init((out, out))
  bias              = init(out)
  initial_output    = init(out)
  return Recurrent(input_weights, recurrent_weights, bias, initial_output)
end
|
2016-10-31 11:01:19 +00:00
|
|
|
|
|
2016-11-02 00:36:13 +00:00
|
|
|
|
# Gated recurrent layer (GRU-style update).
#
# Two σ gates are computed from the input and the previous output `y{-1}`:
#   * the reset gate scales the previous output before it enters the
#     candidate computation;
#   * the update gate blends the candidate with the previous output:
#         y = (1 - update) * candidate + update * y{-1}
#
# `y{-1}` is the `@net` delay notation (presumably the previous step's `y`,
# seeded from the stored `y` field — confirm against `@net`).
@net type GatedRecurrent
  # Reset gate: input weights, recurrent weights, bias.
  Wxr; Wyr; br
  # Update gate: input weights, recurrent weights, bias.
  Wxu; Wyu; bu
  # Candidate output: input weights, recurrent weights, bias.
  Wxh; Wyh; bh
  # Recurrent output / hidden value.
  y
  function (x)
    # The two gates are independent of each other.
    update_gate = σ( x * Wxu + y{-1} * Wyu + bu )
    reset_gate  = σ( x * Wxr + y{-1} * Wyr + br )
    # Candidate sees the previous output only through the reset gate.
    candidate = tanh( x * Wxh + (reset_gate .* y{-1}) * Wyh + bh )
    # Element-wise interpolation between candidate and previous output.
    y = (1 .- update_gate) .* candidate + update_gate .* y{-1}
  end
end
|
|
|
|
|
|
|
|
|
|
# Convenience constructor: build a `GatedRecurrent` layer mapping `in`
# features to `out` features.
#
# The three parameter groups (reset, update, candidate — in field order) are
# each created by calling `init` with `(in, out)`, `(out, out)` and `out`.
# The initial output `y` is a `Float32` zero vector of length `out`.
# `init` defaults to `initn`.
function GatedRecurrent(in, out; init = initn)
  # One (input weights, recurrent weights, bias) triple; tuple elements are
  # evaluated left to right, so `init` is invoked in field order.
  gate_params() = (init((in, out)), init((out, out)), init(out))
  initial_output = zeros(Float32, out)
  return GatedRecurrent(gate_params()..., gate_params()..., gate_params()...,
                        initial_output)
end
|
|
|
|
|
|
2016-10-31 11:01:19 +00:00
|
|
|
|
# Long short-term memory layer.
#
# Keeps two recurrent values: the output `y` and the cell `state`. On each
# step three σ gates (forget, input, output) are computed from the input and
# the previous output; the cell state is updated as
#
#     state = forget .* state{-1} .+ input .* candidate
#
# and the new output is the output gate times `tanh(state)`.
#
# `y{-1}` / `state{-1}` are the `@net` delay notation (presumably the values
# from the previous step, seeded from the stored fields — confirm against
# `@net`).
@net type LSTM
  # Forget gate: input weights, recurrent weights, bias.
  Wxf; Wyf; bf
  # Input gate: input weights, recurrent weights, bias.
  Wxi; Wyi; bi
  # Output gate: input weights, recurrent weights, bias.
  Wxo; Wyo; bo
  # Candidate cell value: input weights, recurrent weights, bias.
  Wxc; Wyc; bc
  # Recurrent output and cell state.
  y
  state
  function (x)
    # Gates (all driven by the current input and the previous output).
    forget_gate = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
    input_gate  = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
    output_gate = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
    # Proposed new cell content.
    candidate = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
    # Keep part of the old state, admit part of the candidate.
    state = forget_gate .* state{-1} .+ input_gate .* candidate
    # Expose the squashed state through the output gate.
    y = output_gate .* tanh(state)
  end
end
|
|
|
|
|
|
|
|
|
|
# Convenience constructor: build an `LSTM` layer mapping `in` features to
# `out` features.
#
# The four parameter groups (forget, input, output, candidate — in field
# order) are each created by calling `init` with `(in, out)`, `(out, out)`
# and `(1, out)`; note the bias is requested with an explicit leading
# dimension of 1. The initial output `y` and cell `state` are `Float32` zero
# vectors of length `out`. `init` defaults to `initn`.
function LSTM(in, out; init = initn)
  # One (input weights, recurrent weights, bias) triple; tuple elements are
  # evaluated left to right, so `init` is invoked in field order.
  gate_params() = (init((in, out)), init((out, out)), init((1, out)))
  initial_output = zeros(Float32, out)
  initial_state  = zeros(Float32, out)
  return LSTM(gate_params()..., gate_params()..., gate_params()..., gate_params()...,
              initial_output, initial_state)
end
|