diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 44b7c5cd..84658a34 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -30,7 +30,7 @@ graph(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1))) graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim) -graph(::typeof(vcat), a...) = node(cat, 1, a...) +graph(::typeof(vcat), a...) = graph(cat, 1, a...) graph(::Input, x) = x