well, that was easy 😎

This commit is contained in:
Mike J Innes 2016-10-31 11:01:19 +00:00
parent e35380940b
commit 7cd94b4a5d
2 changed files with 25 additions and 2 deletions

View File

@ -12,7 +12,8 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
# Layer stack for the model; layers are listed in the order they are passed to `Chain`.
# NOTE(review): this view appears to be a diff with its +/- markers stripped —
# `Recurrent(N, 128)` is presumably the line replaced by the two `LSTM` layers
# below; confirm against the actual file before relying on this listing.
model = Chain(
Input(N),
Recurrent(N, 128),
LSTM(N, 128),
LSTM(128, 128),
Dense(128, N),
softmax)

View File

@ -1,4 +1,4 @@
export Recurrent
export Recurrent, LSTM
@net type Recurrent
Wxy; Wyy; by
@ -10,3 +10,25 @@ end
# Convenience constructor: builds a `Recurrent` layer mapping `in` features to
# `out` features, drawing every positional argument from `init`.
function Recurrent(in, out; init = initn)
  Wxy = init((in, out))   # input weights
  Wyy = init((out, out))  # recurrent weights
  by  = init(out)         # bias
  # The fourth field is declared outside this view; presumably the initial
  # hidden output — TODO confirm against the type definition.
  y0  = init(out)
  Recurrent(Wxy, Wyy, by, y0)
end
# LSTM layer defined through the `@net` DSL.
#
# Fields:
#   Wxf, Wyf, bf – weights applied to `x`, weights applied to `y`, and bias of the forget gate
#   Wxi, Wyi, bi – same, for the input gate
#   Wxo, Wyo, bo – same, for the output gate
#   Wxc, Wyc, bc – same, for the candidate cell values
#   y, state     – read and reassigned by the forward function; presumably
#                  carried between time steps by `@net` — confirm.
@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf + y * Wyf + bf )
    input  = σ( x * Wxi + y * Wyi + bi )
    output = σ( x * Wxo + y * Wyo + bo )
    # Candidate cell values. Kept under their own name so the existing `state`
    # is still available to the forget gate on the next line; assigning this
    # to `state` directly would discard the previous cell state.
    update = tanh( x * Wxc + y * Wyc + bc )
    # State update and output
    state = forget .* state + input .* update
    y = output .* tanh(state)
  end
end
# Convenience constructor: builds an `LSTM` layer mapping `in` features to
# `out` features. For each of the four parameter groups it creates, in order,
# an (in, out) weight, an (out, out) weight and a length-`out` bias via `init`;
# the two trailing fields start as Float32 zeros.
function LSTM(in, out; init = initn)
  params = Any[]
  for _ = 1:4
    push!(params, init((in, out)))
    push!(params, init((out, out)))
    push!(params, init(out))
  end
  y0     = zeros(Float32, out)
  state0 = zeros(Float32, out)
  LSTM(params..., y0, state0)
end