get this somewhat working
This commit is contained in:
parent
417a70713b
commit
8d63bf8053
|
@ -19,6 +19,7 @@ node(x::mx.SymbolicNode) = x
|
|||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(*), x, W) = mx.dot(transpose(W), x) # Adjustments for batching
|
||||
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
|
||||
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
|
||||
|
@ -32,12 +33,13 @@ graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
|||
|
||||
graph(::Input, x) = x
|
||||
|
||||
# TODO: use actual params
|
||||
|
||||
graph(ctx::Context, d::Affine, x) =
|
||||
mx.FullyConnected(data = x,
|
||||
num_hidden = size(d.W.x, 2))
|
||||
num_hidden = size(d.W.x, 2),
|
||||
weight = var(ctx, d.W),
|
||||
bias = var(ctx, d.b, size(d.b, 2)))
|
||||
|
||||
# TODO: use actual params}
|
||||
graph(ctx::Context, c::Conv2D, x) =
|
||||
mx.Convolution(data = x,
|
||||
kernel = size(c.filter, 1, 2),
|
||||
|
@ -57,9 +59,9 @@ end
|
|||
|
||||
register(ctx::Context, node) = node
|
||||
|
||||
function var(ctx::Context, p::Flux.Param)
|
||||
function var(ctx::Context, p::Flux.Param, size = nothing)
|
||||
id = gensym()
|
||||
ctx[:params][id] = p.x
|
||||
ctx[:params][id] = size == nothing ? p.x : reshape(p.x, size...)
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
|
|
|
@ -47,13 +47,17 @@ function mxnet(model::Flux.Model, input)
|
|||
return model
|
||||
end
|
||||
|
||||
# MNet batches on last dimension
|
||||
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||
|
||||
function runmodel(model::Model, input)
|
||||
copy!(model.exec.arg_dict[:input], input)
|
||||
mx.forward(model.exec, is_train = true)
|
||||
copy(model.exec.outputs[1])
|
||||
end
|
||||
|
||||
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
|
||||
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
|
||||
|
||||
(m::Model)(x) = first(m(batchone(x)))
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ d = Affine(20, 10)
|
|||
|
||||
# MXNet
|
||||
|
||||
@mxonly let dm = mxnet(d, (1, 20))
|
||||
@mxonly let dm = mxnet(d, (20, 1))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue