conv/affine hacks
This commit is contained in:
parent
96c0e76b92
commit
12cde694b3
@ -32,23 +32,23 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
# graph(vars, c::Conv, x) =
|
||||
# mx.Convolution(data = x,
|
||||
# kernel = c.size,
|
||||
# num_filter = c.features,
|
||||
# stride = c.stride)
|
||||
#
|
||||
# graph(vars, p::MaxPool, x) =
|
||||
# mx.Pooling(data = x,
|
||||
# pool_type = :max,
|
||||
# kernel = p.size,
|
||||
# stride = p.stride)
|
||||
#
|
||||
# graph(vars, d::Dense, x) =
|
||||
# mx.FullyConnected(data = x,
|
||||
# num_hidden = size(d.W.x, 1),
|
||||
# weight = graph(vars, d.W),
|
||||
# bias = graph(vars, d.b))
|
||||
# TODO: use actual params
|
||||
|
||||
graph(ctx::Context, d::Affine, x) =
|
||||
mx.FullyConnected(data = x,
|
||||
num_hidden = size(d.W.x, 2))
|
||||
|
||||
graph(ctx::Context, c::Conv2D, x) =
|
||||
mx.Convolution(data = x,
|
||||
kernel = size(c.filter, 1, 2),
|
||||
num_filter = size(c.filter, 4),
|
||||
stride = c.stride)
|
||||
|
||||
graph(ctx::Context, p::MaxPool, x) =
|
||||
mx.Pooling(data = x,
|
||||
pool_type = :max,
|
||||
kernel = p.size,
|
||||
stride = p.stride)
|
||||
|
||||
function register(ctx::Context, node::mx.SymbolicNode)
|
||||
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||
|
Loading…
Reference in New Issue
Block a user