make these fit with julia semantics
This commit is contained in:
parent
b7caaf4a65
commit
f31b539566
@ -5,7 +5,7 @@ export Affine
|
|||||||
@net type Affine
|
@net type Affine
|
||||||
W
|
W
|
||||||
b
|
b
|
||||||
x -> x*W + b
|
x -> x*W .+ b
|
||||||
end
|
end
|
||||||
|
|
||||||
Affine(in::Integer, out::Integer; init = initn) =
|
Affine(in::Integer, out::Integer; init = initn) =
|
||||||
|
@ -36,16 +36,16 @@ GatedRecurrent(in, out; init = initn) =
|
|||||||
y; state
|
y; state
|
||||||
function (x)
|
function (x)
|
||||||
# Gates
|
# Gates
|
||||||
forget = σ( x * Wxf + y{-1} * Wyf + bf )
|
forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
|
||||||
input = σ( x * Wxi + y{-1} * Wyi + bi )
|
input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
|
||||||
output = σ( x * Wxo + y{-1} * Wyo + bo )
|
output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
|
||||||
# State update and output
|
# State update and output
|
||||||
state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
|
state′ = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
|
||||||
state = forget .* state{-1} + input .* state′
|
state = forget .* state{-1} .+ input .* state′
|
||||||
y = output .* tanh(state)
|
y = output .* tanh(state)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
LSTM(in, out; init = initn) =
|
LSTM(in, out; init = initn) =
|
||||||
LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
|
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
|
||||||
zeros(Float32, out), zeros(Float32, out))
|
zeros(Float32, 1, out), zeros(Float32, 1, out))
|
||||||
|
Loading…
Reference in New Issue
Block a user