more mxnet ops

This commit is contained in:
Mike J Innes 2017-06-02 14:42:15 +01:00
parent e076bee00a
commit edb1099cec

View File

@ -10,7 +10,7 @@ using Base: @get!
using DataFlow: Constant, constant
using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace
using Flux: mapt, broadcastto
import Flux: mapt, broadcastto,
# TODO: implement Julia's type promotion rules
@ -28,8 +28,9 @@ graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1)
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...)
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
# Old broadcasters
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
@ -47,6 +48,8 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
a::mx.SymbolicNode b::mx.SymbolicNode = mx.broadcast_mul(a, b)
graph(::Input, x) = x
struct AlterParam