more mxnet ops
This commit is contained in:
parent
e076bee00a
commit
edb1099cec
@ -10,7 +10,7 @@ using Base: @get!
|
|||||||
using DataFlow: Constant, constant
|
using DataFlow: Constant, constant
|
||||||
using DataFlow.Interpreter
|
using DataFlow.Interpreter
|
||||||
using DataFlow.Interpreter: Exception, totrace
|
using DataFlow.Interpreter: Exception, totrace
|
||||||
using Flux: mapt, broadcastto
|
import Flux: mapt, broadcastto, ∘
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
|
|
||||||
@ -28,8 +28,9 @@ graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1)
|
|||||||
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
|
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
|
||||||
|
|
||||||
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
|
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||||
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
|
|
||||||
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
|
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
|
||||||
|
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
|
||||||
|
graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...)
|
||||||
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
|
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
|
||||||
# Old broadcasters
|
# Old broadcasters
|
||||||
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
|
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
|
||||||
@ -47,6 +48,8 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
|||||||
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
|
||||||
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
|
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
|
||||||
|
|
||||||
|
a::mx.SymbolicNode ∘ b::mx.SymbolicNode = mx.broadcast_mul(a, b)
|
||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
struct AlterParam
|
struct AlterParam
|
||||||
|
Loading…
Reference in New Issue
Block a user