fixes for recurrent networks
This commit is contained in:
parent
358334a893
commit
14afe54143
@ -19,6 +19,8 @@ node(x::mx.SymbolicNode) = x
|
|||||||
|
|
||||||
graph(::typeof(tuple), args...) = (args...,)
|
graph(::typeof(tuple), args...) = (args...,)
|
||||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||||
|
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
|
||||||
|
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)
|
||||||
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
|
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
|
||||||
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
|
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
|
||||||
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
|
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
|
||||||
|
@ -25,8 +25,8 @@ Recurrent(in, out; init = initn) =
|
|||||||
end
|
end
|
||||||
|
|
||||||
GatedRecurrent(in, out; init = initn) =
|
GatedRecurrent(in, out; init = initn) =
|
||||||
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)...,
|
GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)...,
|
||||||
zeros(Float32, out))
|
zeros(Float32, (1, out)))
|
||||||
|
|
||||||
@net type LSTM
|
@net type LSTM
|
||||||
Wxf; Wyf; bf
|
Wxf; Wyf; bf
|
||||||
@ -48,4 +48,4 @@ end
|
|||||||
|
|
||||||
LSTM(in, out; init = initn) =
|
LSTM(in, out; init = initn) =
|
||||||
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
|
LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)...,
|
||||||
zeros(Float32, out), zeros(Float32, out))
|
zeros(Float32, (1, out)), zeros(Float32, (1, out)))
|
||||||
|
Loading…
Reference in New Issue
Block a user