fixes for recurrent networks
This commit is contained in:
parent
358334a893
commit
14afe54143
@ -19,6 +19,8 @@ node(x::mx.SymbolicNode) = x
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
||||
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)
|
||||
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
|
||||
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
|
||||
|
@ -25,8 +25,8 @@ Recurrent(in, out; init = initn) =
|
||||
end
|
||||
|
||||
GatedRecurrent(in, out; init = initn) =
|
||||
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
|
||||
zeros(Float32, out))
|
||||
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)...,
|
||||
zeros(Float32, (1, out)))
|
||||
|
||||
@net type LSTM
|
||||
Wxf; Wyf; bf
|
||||
@ -48,4 +48,4 @@ end
|
||||
|
||||
LSTM(in, out; init = initn) =
|
||||
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
|
||||
zeros(Float32, out), zeros(Float32, out))
|
||||
zeros(Float32, (1, out)), zeros(Float32, (1, out)))
|
||||
|
Loading…
Reference in New Issue
Block a user