diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index c71b0be2..821247fb 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -32,23 +32,23 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...) graph(::Input, x) = x -# graph(vars, c::Conv, x) = -# mx.Convolution(data = x, -# kernel = c.size, -# num_filter = c.features, -# stride = c.stride) -# -# graph(vars, p::MaxPool, x) = -# mx.Pooling(data = x, -# pool_type = :max, -# kernel = p.size, -# stride = p.stride) -# -# graph(vars, d::Dense, x) = -# mx.FullyConnected(data = x, -# num_hidden = size(d.W.x, 1), -# weight = graph(vars, d.W), -# bias = graph(vars, d.b)) +# TODO: use actual params + +graph(ctx::Context, d::Affine, x) = + mx.FullyConnected(data = x, + num_hidden = size(d.W.x, 2)) + +graph(ctx::Context, c::Conv2D, x) = + mx.Convolution(data = x, + kernel = size(c.filter, 1, 2), + num_filter = size(c.filter, 4), + stride = c.stride) + +graph(ctx::Context, p::MaxPool, x) = + mx.Pooling(data = x, + pool_type = :max, + kernel = p.size, + stride = p.stride) function register(ctx::Context, node::mx.SymbolicNode) ctx[:stacks][nodename(node)] = stack(ctx)