From edb1099cec9c40d405f5aae8a0da19507508a65f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 2 Jun 2017 14:42:15 +0100 Subject: [PATCH] more mxnet ops --- src/backend/mxnet/graph.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 18372f68..e3f26494 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -10,7 +10,7 @@ using Base: @get! using DataFlow: Constant, constant using DataFlow.Interpreter using DataFlow.Interpreter: Exception, totrace -using Flux: mapt, broadcastto +import Flux: mapt, broadcastto, ∘ # TODO: implement Julia's type promotion rules @@ -28,8 +28,9 @@ graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1) graph(::typeof(vec), xs) = reshape(xs, shape = (-1,)) graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...) -graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...) graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...) +graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...) +graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...) graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape))) # Old broadcasters graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs) @@ -47,6 +48,8 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) graph(::typeof(getindex), t::Tuple, n::Integer) = t[n] graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs) +a::mx.SymbolicNode ∘ b::mx.SymbolicNode = mx.broadcast_mul(a, b) + graph(::Input, x) = x struct AlterParam