diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 313519a0..c6f17e1e 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -43,6 +43,7 @@ graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim) graph(::typeof(vcat), a...) = graph(cat, 1, a...) graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) +graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs) graph(::Input, x) = x