more mxnet ops
This commit is contained in:
parent
e076bee00a
commit
edb1099cec
@ -10,7 +10,7 @@ using Base: @get!
|
||||
using DataFlow: Constant, constant
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: Exception, totrace
|
||||
using Flux: mapt, broadcastto
|
||||
import Flux: mapt, broadcastto, ∘
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
@ -28,8 +28,9 @@ graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1)
|
||||
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
|
||||
|
||||
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
|
||||
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
|
||||
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
|
||||
graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...)
|
||||
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
|
||||
# Old broadcasters
|
||||
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
|
||||
@ -47,6 +48,8 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
|
||||
|
||||
a::mx.SymbolicNode ∘ b::mx.SymbolicNode = mx.broadcast_mul(a, b)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
struct AlterParam
|
||||
|
Loading…
Reference in New Issue
Block a user