conv/affine hacks

This commit is contained in:
Mike J Innes 2017-02-21 08:52:40 +00:00
parent 96c0e76b92
commit 12cde694b3

View File

@ -32,23 +32,23 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
graph(::Input, x) = x
# graph(vars, c::Conv, x) =
# mx.Convolution(data = x,
# kernel = c.size,
# num_filter = c.features,
# stride = c.stride)
#
# graph(vars, p::MaxPool, x) =
# mx.Pooling(data = x,
# pool_type = :max,
# kernel = p.size,
# stride = p.stride)
#
# graph(vars, d::Dense, x) =
# mx.FullyConnected(data = x,
# num_hidden = size(d.W.x, 1),
# weight = graph(vars, d.W),
# bias = graph(vars, d.b))
# TODO: use actual params
graph(ctx::Context, d::Affine, x) =
mx.FullyConnected(data = x,
num_hidden = size(d.W.x, 2))
graph(ctx::Context, c::Conv2D, x) =
mx.Convolution(data = x,
kernel = size(c.filter, 1, 2),
num_filter = size(c.filter, 4),
stride = c.stride)
graph(ctx::Context, p::MaxPool, x) =
mx.Pooling(data = x,
pool_type = :max,
kernel = p.size,
stride = p.stride)
function register(ctx::Context, node::mx.SymbolicNode)
ctx[:stacks][nodename(node)] = stack(ctx)