From e076bee00a149d8ea595740fe82f2fc0b3b1dbf3 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 2 Jun 2017 13:44:29 +0100 Subject: [PATCH] mxnet broadcastto --- src/backend/mxnet/graph.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index ae7cc03a..18372f68 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -10,7 +10,7 @@ using Base: @get! using DataFlow: Constant, constant using DataFlow.Interpreter using DataFlow.Interpreter: Exception, totrace -using Flux: mapt +using Flux: mapt, broadcastto # TODO: implement Julia's type promotion rules @@ -30,6 +30,7 @@ graph(::typeof(vec), xs) = reshape(xs, shape = (-1,)) graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...) graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...) graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...) +graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape))) # Old broadcasters graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs) graph(::typeof(.+), args...) = mx.broadcast_plus(args...)