diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index d5213469..d557f68b 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -10,7 +10,7 @@ using Base: @get!
 using DataFlow: Constant, constant
 using DataFlow.Interpreter
 using DataFlow.Interpreter: Exception, totrace
-using Flux: imap
+using Flux: imap, mapt
 
 # TODO: implement Julia's type promotion rules
 
@@ -84,8 +84,8 @@ function tograph(model, args...; feedforward = false)
   ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
                 params = Dict(), stacks = Dict(),
                 feedforward = feedforward)
-  out = @ithrow graph(ctx, model, args...)
-  return Graph(out, ctx[:params], ctx[:stacks])
+  out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
+  return Graph(args, out, ctx[:params], ctx[:stacks])
 end
 
 # Error Handling
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 62393d09..0f7df34f 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -16,6 +16,7 @@ function copyargs!(as, bs)
 end
 
 struct Graph
+  input
   output
   params::Dict{Symbol,Any}
   stacks::Dict{Any,Any}
@@ -48,8 +49,8 @@ mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 
 function executor(graph::Graph, input)
-  args = merge(mxparams(graph), Dict(:input => MXArray(input)))
-  grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
+  args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
+  grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
   exec = mx.bind(mxgroup(graph.output),
                  args = ndparams(args),
                  args_grad = ndparams(grads),
@@ -60,15 +61,15 @@ function executor(graph::Graph, input)
 end
 
 function (exec::Exec)(input)
-  copy!(exec.args[:input], input)
+  copy!(exec.args[exec.graph.input[1]], input)
   mx.forward(exec.exec, is_train = true)
   mxungroup(exec.graph.output, copy(exec.outs))
 end
 
 function Flux.back!(exec::Exec, Δ)
-  exec.grads[:input][:] = 0
+  exec.grads[exec.graph.input[1]][:] = 0
   mx.backward(exec.exec, MXArray(Δ).data)
-  copy(exec.grads[:input])
+  copy(exec.grads[exec.graph.input[1]])
 end
 
 function Flux.update!(exec::Exec, η)
@@ -86,22 +87,21 @@ end
 
 mutable struct Model <: Flux.Model
   model::Any
-  graph::Graph
   execs::Dict{Tuple,Exec}
+  graph::Graph
   last::Exec
-  Model(model, graph, execs) = new(model, graph, execs)
+  Model(model) = new(model, Dict())
 end
 
-function mxnet(model)
-  graph = tograph(model, mx.Variable(:input))
-  Model(model, graph, Dict())
-end
+mxnet(model) = Model(model)
 
 import Base: @get!
 
 executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
 
 function (m::Model)(x)
+  !isdefined(m, :graph) &&
+    (m.graph = tograph(m.model, mapt(_ -> gensym("input"), x)))
   @mxerr m.graph.stacks runrawbatched(x) do x
     m.last = exec = executor(m, size(x))
     exec(x)
@@ -134,7 +134,7 @@ end
 function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax,
                         context = mx.cpu())
   model = rewrite_softmax(model, label)
-  graph = tograph(model, mx.Variable(input), feedforward=true)
+  graph = tograph(model, input, feedforward=true)
   ff = mx.FeedForward(graph.output, context = context)
   isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
   return ff
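
The idea behind this change can be sketched in a few lines of plain Julia. The `mapt` definition here is a simplified stand-in for Flux's helper of the same name (assumed to map a function over nested tuples, as the imports above suggest), and `args`, `input_names`, and `exec_args` are hypothetical names used only to illustrate why the executor now keys its argument dict by `graph.input[1]` instead of the fixed `:input` symbol:

# Simplified stand-in for Flux's `mapt`: apply `f` to every leaf of a
# (possibly nested) tuple of model arguments.
mapt(f, x) = f(x)
mapt(f, x::Tuple) = map(t -> mapt(f, t), x)

# Each Graph now records freshly generated input names instead of relying
# on a single shared `:input` symbol.
args = (rand(10),)                              # one model argument
input_names = mapt(_ -> gensym("input"), args)  # e.g. (Symbol("##input#123"),)

# Mirrors `Dict(graph.input[1] => MXArray(input))` from the diff: the
# executor looks up the graph's own first input name as the dict key.
exec_args = Dict(input_names[1] => args[1])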