diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index ae7cc03a..18372f68 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -10,7 +10,7 @@ using Base: @get! using DataFlow: Constant, constant using DataFlow.Interpreter using DataFlow.Interpreter: Exception, totrace -using Flux: mapt +using Flux: mapt, broadcastto # TODO: implement Julia's type promotion rules @@ -30,6 +30,7 @@ graph(::typeof(vec), xs) = reshape(xs, shape = (-1,)) graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...) graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...) graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...) +graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape))) # Old broadcasters graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs) graph(::typeof(.+), args...) = mx.broadcast_plus(args...)