# Flux.jl/src/layers/recurrent.jl
export Recurrent, GatedRecurrent, LSTM
# Vanilla (Elman-style) recurrent cell: carries a hidden state `y` across
# calls and updates it from the current input through a tanh nonlinearity.
@net type Recurrent
  Wxy; Wyy; by  # input→hidden weights, hidden→hidden weights, bias
  y             # hidden state, reassigned on every call
  function (x)
    # NOTE(review): `@net` presumably treats the reassignment of the field
    # `y` as the state update — confirm against the macro's semantics.
    y = tanh( x * Wxy + y * Wyy + by )
  end
end
# Construct a `Recurrent` cell mapping `in` features to `out` features.
# Weights are drawn from `init`; the hidden state starts at zero, consistent
# with the `GatedRecurrent` and `LSTM` constructors (previously it was drawn
# from `init` as well, giving a random initial state).
Recurrent(in, out; init = initn) =
  Recurrent(init((in, out)), init((out, out)), init(out), zeros(Float32, out))
# Gated Recurrent Unit (GRU) cell, after Cho et al. 2014.
@net type GatedRecurrent
  Wxr; Wyr; br  # reset-gate parameters
  Wxu; Wyu; bu  # update-gate parameters
  Wxh; Wyh; bh  # candidate-state parameters
  y             # hidden state
  function (x)
    reset  = σ( x * Wxr + y * Wyr + br )
    update = σ( x * Wxu + y * Wyu + bu )
    # The candidate is kept in its own variable so the interpolation below
    # still sees the *previous* `y`. The original assigned the candidate to
    # `y` first, which made `(1 .- update) .* y + update .* y` a no-op:
    # the update gate had no effect and the old state was discarded.
    candidate = tanh( x * Wxh + (reset .* y) * Wyh + bh )
    y = (1 .- update) .* y + update .* candidate
  end
end
# Build a GRU cell with `init`-drawn weight/bias triples for the three gates
# and a zero initial hidden state.
function GatedRecurrent(in, out; init = initn)
  gate() = [init((in, out)), init((out, out)), init(out)]
  GatedRecurrent(vcat(gate(), gate(), gate())..., zeros(Float32, out))
end
# Long Short-Term Memory cell with forget, input, and output gates and a
# separate cell `state` alongside the hidden output `y`.
@net type LSTM
  Wxf; Wyf; bf  # forget gate
  Wxi; Wyi; bi  # input gate
  Wxo; Wyo; bo  # output gate
  Wxc; Wyc; bc  # candidate cell state
  y; state      # hidden output and cell state
  function (x)
    # Gates
    forget = σ( x * Wxf + y * Wyf + bf )
    input  = σ( x * Wxi + y * Wyi + bi )
    output = σ( x * Wxo + y * Wyo + bo )
    # State update and output. The candidate gets its own variable so the
    # forget term still multiplies the *previous* cell state; the original
    # overwrote `state` with the candidate first, so the blend combined the
    # candidate with itself and the old cell state was lost.
    candidate = tanh( x * Wxc + y * Wyc + bc )
    state = forget .* state + input .* candidate
    y = output .* tanh(state)
  end
end
# Build an LSTM cell with `init`-drawn weight/bias triples for the four
# gates; the hidden output and the cell state both start at zero.
function LSTM(in, out; init = initn)
  gate() = [init((in, out)), init((out, out)), init(out)]
  LSTM(vcat(gate(), gate(), gate(), gate())...,
       zeros(Float32, out), zeros(Float32, out))
end