get this somewhat working

This commit is contained in:
Mike J Innes 2017-02-21 12:58:31 +00:00
parent 417a70713b
commit 8d63bf8053
3 changed files with 13 additions and 7 deletions

View File

@ -19,6 +19,7 @@ node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
@ -32,12 +33,13 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
graph(::Input, x) = x
# TODO: use actual params
graph(ctx::Context, d::Affine, x) =
mx.FullyConnected(data = x,
num_hidden = size(d.W.x, 2))
num_hidden = size(d.W.x, 2),
weight = var(ctx, d.W),
bias = var(ctx, d.b, size(d.b, 2)))
# TODO: use actual params}
graph(ctx::Context, c::Conv2D, x) =
mx.Convolution(data = x,
kernel = size(c.filter, 1, 2),
@ -57,9 +59,9 @@ end
register(ctx::Context, node) = node
function var(ctx::Context, p::Flux.Param)
function var(ctx::Context, p::Flux.Param, size = nothing)
id = gensym()
ctx[:params][id] = p.x
ctx[:params][id] = size == nothing ? p.x : reshape(p.x, size...)
return mx.Variable(id)
end

View File

@ -47,13 +47,17 @@ function mxnet(model::Flux.Model, input)
return model
end
# MNet batches on last dimension
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
function runmodel(model::Model, input)
copy!(model.exec.arg_dict[:input], input)
mx.forward(model.exec, is_train = true)
copy(model.exec.outputs[1])
end
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
(m::Model)(x) = first(m(batchone(x)))

View File

@ -3,7 +3,7 @@ d = Affine(20, 10)
# MXNet
@mxonly let dm = mxnet(d, (1, 20))
@mxonly let dm = mxnet(d, (20, 1))
@test d(xs) dm(xs)
end