From 14afe54143da54870e0e7a428addab607d91a69b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 19 Apr 2017 17:17:37 +0100 Subject: [PATCH] fixes for recurrent networks --- src/backend/mxnet/graph.jl | 2 ++ src/layers/recurrent.jl | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index d557f68b..edc00491 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -19,6 +19,8 @@ node(x::mx.SymbolicNode) = x graph(::typeof(tuple), args...) = (args...,) graph(::typeof(.+), args...) = mx.broadcast_plus(args...) +graph(::typeof(.*), args...) = mx.broadcast_mul(args...) +graph(::typeof(.-), args...) = mx.broadcast_sub(args...) graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid) graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index 73228e9f..3874287b 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -25,8 +25,8 @@ Recurrent(in, out; init = initn) = end GatedRecurrent(in, out; init = initn) = - GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(out)] for _ = 1:3]...)..., - zeros(Float32, out)) + GatedRecurrent(vcat([[init((in, out)), init((out, out)), init(1, out)] for _ = 1:3]...)..., + zeros(Float32, (1, out))) @net type LSTM Wxf; Wyf; bf @@ -48,4 +48,4 @@ end LSTM(in, out; init = initn) = LSTM(vcat([[init((in, out)), init((out, out)), init((1, out))] for _ = 1:4]...)..., - zeros(Float32, out), zeros(Float32, out)) + zeros(Float32, (1, out)), zeros(Float32, (1, out)))