fixes for recurrent networks

This commit is contained in:
Mike J Innes 2017-04-19 17:17:37 +01:00
parent 358334a893
commit 14afe54143
2 changed files with 5 additions and 3 deletions

View File

@ -19,6 +19,8 @@ node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,) graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(.+), args...) = mx.broadcast_plus(args...) graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid) graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu) graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)

View File

@ -25,8 +25,8 @@ Recurrent(in, out; init = initn) =
end end
GatedRecurrent(in, out; init = initn) = GatedRecurrent(in, out; init = initn) =
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)..., GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)...,
zeros(Float32, out)) zeros(Float32, (1, out)))
@net type LSTM @net type LSTM
Wxf; Wyf; bf Wxf; Wyf; bf
@ -48,4 +48,4 @@ end
LSTM(in, out; init = initn) = LSTM(in, out; init = initn) =
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)..., LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
zeros(Float32, out), zeros(Float32, out)) zeros(Float32, (1, out)), zeros(Float32, (1, out)))