mxnet broadcastto

This commit is contained in:
Mike J Innes 2017-06-02 13:44:29 +01:00
parent 4712568ac2
commit e076bee00a

View File

@ -10,7 +10,7 @@ using Base: @get!
using DataFlow: Constant, constant
using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace
using Flux: mapt
using Flux: mapt, broadcastto
# TODO: implement Julia's type promotion rules
@ -30,6 +30,7 @@ graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
# Old broadcasters
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)