mxnet broadcastto
This commit is contained in:
parent
4712568ac2
commit
e076bee00a
@ -10,7 +10,7 @@ using Base: @get!
|
||||
using DataFlow: Constant, constant
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: Exception, totrace
|
||||
using Flux: mapt
|
||||
using Flux: mapt, broadcastto
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
@ -30,6 +30,7 @@ graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
|
||||
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
|
||||
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
|
||||
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
|
||||
# Old broadcasters
|
||||
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
|
||||
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
|
||||
|
Loading…
Reference in New Issue
Block a user