conv/affine hacks
This commit is contained in:
parent
96c0e76b92
commit
12cde694b3
@ -32,23 +32,23 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
|||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
# graph(vars, c::Conv, x) =
|
# TODO: use actual params
|
||||||
# mx.Convolution(data = x,
|
|
||||||
# kernel = c.size,
|
graph(ctx::Context, d::Affine, x) =
|
||||||
# num_filter = c.features,
|
mx.FullyConnected(data = x,
|
||||||
# stride = c.stride)
|
num_hidden = size(d.W.x, 2))
|
||||||
#
|
|
||||||
# graph(vars, p::MaxPool, x) =
|
graph(ctx::Context, c::Conv2D, x) =
|
||||||
# mx.Pooling(data = x,
|
mx.Convolution(data = x,
|
||||||
# pool_type = :max,
|
kernel = size(c.filter, 1, 2),
|
||||||
# kernel = p.size,
|
num_filter = size(c.filter, 4),
|
||||||
# stride = p.stride)
|
stride = c.stride)
|
||||||
#
|
|
||||||
# graph(vars, d::Dense, x) =
|
graph(ctx::Context, p::MaxPool, x) =
|
||||||
# mx.FullyConnected(data = x,
|
mx.Pooling(data = x,
|
||||||
# num_hidden = size(d.W.x, 1),
|
pool_type = :max,
|
||||||
# weight = graph(vars, d.W),
|
kernel = p.size,
|
||||||
# bias = graph(vars, d.b))
|
stride = p.stride)
|
||||||
|
|
||||||
function register(ctx::Context, node::mx.SymbolicNode)
|
function register(ctx::Context, node::mx.SymbolicNode)
|
||||||
ctx[:stacks][nodename(node)] = stack(ctx)
|
ctx[:stacks][nodename(node)] = stack(ctx)
|
||||||
|
Loading…
Reference in New Issue
Block a user