well, that was easy 😎
This commit is contained in:
parent
e35380940b
commit
7cd94b4a5d
|
@ -12,7 +12,8 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
|||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
Recurrent(N, 128),
|
||||
LSTM(N, 128),
|
||||
LSTM(128, 128),
|
||||
Dense(128, N),
|
||||
softmax)
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
export Recurrent
|
||||
export Recurrent, LSTM
|
||||
|
||||
@net type Recurrent
|
||||
Wxy; Wyy; by
|
||||
|
@ -10,3 +10,25 @@ end
|
|||
|
||||
Recurrent(in, out; init = initn) =
|
||||
Recurrent(init((in, out)), init((out, out)), init(out), init(out))
|
||||
|
||||
@net type LSTM
|
||||
Wxf; Wyf; bf
|
||||
Wxi; Wyi; bi
|
||||
Wxo; Wyo; bo
|
||||
Wxc; Wyc; bc
|
||||
y; state
|
||||
function (x)
|
||||
# Gates
|
||||
forget = σ( x * Wxf + y * Wyf + bf )
|
||||
input = σ( x * Wxi + y * Wyi + bi )
|
||||
output = σ( x * Wxo + y * Wyo + bo )
|
||||
# State update and output
|
||||
state′ = tanh( x * Wxc + y * Wyc + bc )
|
||||
state = forget .* state + input .* state′
|
||||
y = output .* tanh(state)
|
||||
end
|
||||
end
|
||||
|
||||
LSTM(in, out; init = initn) =
|
||||
LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
|
||||
zeros(Float32, out), zeros(Float32, out))
|
||||
|
|
Loading…
Reference in New Issue