52 lines
1.3 KiB
Julia
52 lines
1.3 KiB
Julia
export Recurrent, GatedRecurrent, LSTM
|
||
|
||
# Simple recurrent layer declared with the `@net` DSL:
#   y = tanh(x * Wxy + y_previous * Wyy + by)
# Fields, in constructor order:
#   Wxy – input weights (the `Recurrent(in, out)` constructor requests size (in, out))
#   Wyy – recurrent weights (requested size (out, out))
#   by  – bias (requested size 1×out)
#   y   – the layer's output; `y{-1}` reads its previous value
#         (assumed to be the `@net` delay syntax — confirm against the macro)
@net type Recurrent
  Wxy; Wyy; by
  y
  # Applied to each input `x`. Given the weight sizes above, `x` is presumably
  # row-oriented with `in` columns — TODO confirm. `.+` broadcasts the 1×out
  # bias across the rows of the product.
  function (x)
    y = tanh( x * Wxy .+ y{-1} * Wyy .+ by )
  end
end
|
||
|
||
# Construct a `Recurrent` layer mapping `in` input features to `out` outputs.
#
# `init` is called once per learnable parameter with a single dimensions tuple,
# the same convention the `LSTM` constructor uses, so a custom initialiser only
# needs to accept a tuple. The previous form mixed `init((in, out))` with
# `init(1, out)`, which no single-tuple initialiser could satisfy.
#
# The initial output `y` starts at zero, matching `GatedRecurrent` and `LSTM`,
# rather than being drawn from the weight initialiser.
Recurrent(in, out; init = initn) =
  Recurrent(init((in, out)), init((out, out)), init((1, out)),
            zeros(Float32, (1, out)))
|
||
|
||
# Gated recurrent (GRU-style) layer declared with the `@net` DSL.
# Three (input weights, recurrent weights, bias) triples, in constructor order:
#   Wxr, Wyr, br – reset gate
#   Wxu, Wyu, bu – update gate
#   Wxh, Wyh, bh – candidate output
# `y` is the layer's output; `y{-1}` reads its previous value
# (assumed to be the `@net` delay syntax — confirm against the macro).
@net type GatedRecurrent
  Wxr; Wyr; br
  Wxu; Wyu; bu
  Wxh; Wyh; bh
  y
  function (x)
    # `σ` is defined elsewhere; presumably the logistic sigmoid — confirm.
    reset = σ( x * Wxr .+ y{-1} * Wyr .+ br )
    update = σ( x * Wxu .+ y{-1} * Wyu .+ bu )
    # Candidate output: the reset gate scales the previous output elementwise
    # before it is multiplied by the recurrent weights.
    y′ = tanh( x * Wxh .+ (reset .* y{-1}) * Wyh .+ bh )
    # Elementwise blend: where `update` is near 1 the previous output is kept,
    # where it is near 0 the candidate `y′` is taken.
    y = (1 .- update) .* y′ .+ update .* y{-1}
  end
end
|
||
|
||
# Construct a `GatedRecurrent` layer mapping `in` input features to `out` outputs.
#
# Parameters are created per gate (reset, update, candidate), each as an
# (input weights, recurrent weights, bias) triple in field order. `init` always
# receives a single dimensions tuple, the same convention the `LSTM` constructor
# uses; the bias was previously requested as `init(1, out)`, which breaks
# initialisers that only accept a tuple. The initial output `y` is zero.
function GatedRecurrent(in, out; init = initn)
  gates = [[init((in, out)), init((out, out)), init((1, out))] for _ = 1:3]
  # Flatten the per-gate triples into the positional field order.
  return GatedRecurrent(vcat(gates...)..., zeros(Float32, (1, out)))
end
|
||
|
||
# Long short-term memory layer declared with the `@net` DSL.
# Four (input weights, recurrent weights, bias) triples, in constructor order:
#   Wxf, Wyf, bf – forget gate
#   Wxi, Wyi, bi – input gate
#   Wxo, Wyo, bo – output gate
#   Wxc, Wyc, bc – candidate cell state
# `y` is the output and `state` the cell state; `y{-1}` / `state{-1}` read their
# previous values (assumed to be the `@net` delay syntax — confirm against the macro).
@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates (`σ` is defined elsewhere; presumably the logistic sigmoid — confirm)
    forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
    input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
    output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
    # State update and output
    # Candidate cell state from the current input and the previous output.
    state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
    # New cell state: forget-scaled previous state plus input-scaled candidate.
    state = forget .* state{-1} .+ input .* state′
    # Output: the tanh of the new cell state, scaled elementwise by the output gate.
    y = output .* tanh(state)
  end
end
|
||
|
||
# Construct an `LSTM` layer mapping `in` input features to `out` outputs.
#
# Each of the four gates (forget, input, output, candidate cell) gets an
# (input weights, recurrent weights, bias) triple, each created by calling
# `init` with a dimensions tuple. The initial output and cell state are zero.
function LSTM(in, out; init = initn)
  gates = [[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]
  y0 = zeros(Float32, (1, out))
  state0 = zeros(Float32, (1, out))
  # Flatten the per-gate triples into positional field order, then the states.
  return LSTM(vcat(gates...)..., y0, state0)
end
|