throw GRU together
This commit is contained in:
parent
53ebb5051a
commit
85415d4244
@ -32,3 +32,20 @@ end
|
|||||||
LSTM(in, out; init = initn) =
|
LSTM(in, out; init = initn) =
|
||||||
LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
|
LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
|
||||||
zeros(Float32, out), zeros(Float32, out))
|
zeros(Float32, out), zeros(Float32, out))
|
||||||
|
|
||||||
|
@net type GatedRecurrent
|
||||||
|
Wxr; Wyr; br
|
||||||
|
Wxu; Wyu; bu
|
||||||
|
Wxh; Wyh; bh
|
||||||
|
state
|
||||||
|
function (x)
|
||||||
|
reset = σ( x * Wxr + y * Wyr + br )
|
||||||
|
update = σ( x * Wxu + y * Wyu + bu )
|
||||||
|
state′ = tanh( x * Wxh + (reset .* y) * Wyh + bh )
|
||||||
|
state = (1 .- update) .* state′ + update .* y
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
GatedRecurrent(in, out; init = initn) =
|
||||||
|
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
|
||||||
|
zeros(Float32, out))
|
||||||
|
Loading…
Reference in New Issue
Block a user