make these fit with julia semantics

This commit is contained in:
Mike J Innes 2016-11-15 16:40:17 +00:00
parent b7caaf4a65
commit f31b539566
2 changed files with 8 additions and 8 deletions

View File

@ -5,7 +5,7 @@ export Affine
@net type Affine
W
b
x -> x*W + b
x -> x*W .+ b
end
Affine(in::Integer, out::Integer; init = initn) =

View File

@ -36,16 +36,16 @@ GatedRecurrent(in, out; init = initn) =
y; state
function (x)
# Gates
forget = σ( x * Wxf + y{-1} * Wyf + bf )
input = σ( x * Wxi + y{-1} * Wyi + bi )
output = σ( x * Wxo + y{-1} * Wyo + bo )
forget = σ( x * Wxf .+ y{-1} * Wyf .+ bf )
input = σ( x * Wxi .+ y{-1} * Wyi .+ bi )
output = σ( x * Wxo .+ y{-1} * Wyo .+ bo )
# State update and output
state = tanh( x * Wxc + y{-1} * Wyc + bc )
state = forget .* state{-1} + input .* state
state = tanh( x * Wxc .+ y{-1} * Wyc .+ bc )
state = forget .* state{-1} .+ input .* state
y = output .* tanh(state)
end
end
LSTM(in, out; init = initn) =
LSTM(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:4]...)...,
zeros(Float32, out), zeros(Float32, out))
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
zeros(Float32, 1, out), zeros(Float32, 1, out))