build graphs lazily

This commit is contained in:
Mike J Innes 2017-03-30 18:14:08 +01:00
parent acbc4ea071
commit 5df56b6073
2 changed files with 15 additions and 15 deletions

View File

@@ -10,7 +10,7 @@ using Base: @get!
using DataFlow: Constant, constant using DataFlow: Constant, constant
using DataFlow.Interpreter using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace using DataFlow.Interpreter: Exception, totrace
using Flux: imap using Flux: imap, mapt
# TODO: implement Julia's type promotion rules # TODO: implement Julia's type promotion rules
@@ -84,8 +84,8 @@ function tograph(model, args...; feedforward = false)
# Body of `tograph(model, args...; feedforward = false)`, shown as a side-by-side diff:
# on every line the left copy is the pre-commit text and the right copy is the post-commit
# text (identical copies are unchanged context).
#
# Creates a `Context` with empty `params` and `stacks` dicts plus the `feedforward` flag,
# evaluates `graph(ctx, model, ...)` under `@ithrow`, and returns a `Graph` built from the
# result together with `ctx[:params]` and `ctx[:stacks]`.
#
# Changes made by the commit:
#   * the arguments are passed as `mapt(mx.Variable, args)...` instead of `args...`
#     (presumably `mapt` maps over the tuple of inputs; confirm against Flux's `mapt`);
#   * the original, unwrapped `args` become the new first argument of `Graph`.
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph), ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph),
params = Dict(), stacks = Dict(), params = Dict(), stacks = Dict(),
feedforward = feedforward) feedforward = feedforward)
out = @ithrow graph(ctx, model, args...) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
return Graph(out, ctx[:params], ctx[:stacks]) return Graph(args, out, ctx[:params], ctx[:stacks])
end end
# Error Handling # Error Handling

View File

@@ -16,6 +16,7 @@ function copyargs!(as, bs)
end end
# Fields of `Graph`, shown as a side-by-side diff (left copy: before, right copy: after).
# The commit adds an untyped `input` field ahead of `output` (the line that appears only
# once); `output`, `params` and `stacks` are unchanged.
# The struct's closing `end` lies outside this hunk.
struct Graph struct Graph
input
output output
params::Dict{Symbol,Any} params::Dict{Symbol,Any}
stacks::Dict{Any,Any} stacks::Dict{Any,Any}
@@ -48,8 +49,8 @@ mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x) mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
# Start of `executor(graph::Graph, input)`, shown as a side-by-side diff (left copy: before,
# right copy: after). The remainder of the function lies outside this hunk.
#
# `args` and `grads` are each built by merging `mxparams(graph)` with a one-entry `Dict`
# holding `MXArray(input)`. Both are passed through `ndparams` to `mx.bind`, which is called
# on `mxgroup(graph.output)`.
#
# The commit changes the key of that one entry from the literal `:input` to
# `graph.input[1]`.
# NOTE(review): only the first element of `graph.input` is used here, so this path appears
# to support a single model input only; confirm that is intended.
function executor(graph::Graph, input) function executor(graph::Graph, input)
args = merge(mxparams(graph), Dict(:input => MXArray(input))) args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
grads = merge(mxparams(graph), Dict(:input => MXArray(input))) grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
exec = mx.bind(mxgroup(graph.output), exec = mx.bind(mxgroup(graph.output),
args = ndparams(args), args = ndparams(args),
args_grad = ndparams(grads), args_grad = ndparams(grads),
@@ -60,15 +61,15 @@ function executor(graph::Graph, input)
end end
# Call method for `Exec`, shown as a side-by-side diff (left copy: before, right copy: after).
#
# Copies `input` into the matching entry of `exec.args`, runs `mx.forward` on `exec.exec`
# with `is_train = true`, and returns `mxungroup` applied to `exec.graph.output` and a copy
# of `exec.outs`.
#
# The commit replaces the hard-coded key `:input` with `exec.graph.input[1]`.
function (exec::Exec)(input) function (exec::Exec)(input)
copy!(exec.args[:input], input) copy!(exec.args[exec.graph.input[1]], input)
mx.forward(exec.exec, is_train = true) mx.forward(exec.exec, is_train = true)
mxungroup(exec.graph.output, copy(exec.outs)) mxungroup(exec.graph.output, copy(exec.outs))
end end
# `Flux.back!` for `Exec`, shown as a side-by-side diff (left copy: before, right copy: after).
#
# Sets the input's entry in `exec.grads` to zero, runs `mx.backward` on `exec.exec` with
# `MXArray(Δ).data`, and returns a copy of that same `exec.grads` entry.
#
# The commit replaces the hard-coded key `:input` with `exec.graph.input[1]` in both the
# zeroing line and the returned lookup.
function Flux.back!(exec::Exec, Δ) function Flux.back!(exec::Exec, Δ)
exec.grads[:input][:] = 0 exec.grads[exec.graph.input[1]][:] = 0
mx.backward(exec.exec, MXArray(Δ).data) mx.backward(exec.exec, MXArray(Δ).data)
copy(exec.grads[:input]) copy(exec.grads[exec.graph.input[1]])
end end
function Flux.update!(exec::Exec, η) function Flux.update!(exec::Exec, η)
@@ -86,22 +87,21 @@ end
# `Model` wrapper struct, shown as a side-by-side diff (left copy: before, right copy:
# after). Lines that appear only once exist on one side of the diff only.
#
# Before the commit the field order was `model`, `graph`, `execs`, `last`, and the inner
# constructor took `(model, graph, execs)`.
# After the commit `graph` is declared after `execs`, and the inner constructor takes only
# `model` and calls `new(model, Dict())`. That fills `model` and `execs` and leaves `graph`
# and `last` unassigned until they are set later.
mutable struct Model <: Flux.Model mutable struct Model <: Flux.Model
model::Any model::Any
graph::Graph
execs::Dict{Tuple,Exec} execs::Dict{Tuple,Exec}
graph::Graph
last::Exec last::Exec
Model(model, graph, execs) = new(model, graph, execs) Model(model) = new(model, Dict())
end end
# `mxnet(model)`, shown as a side-by-side diff (left copy: before, right copy: after).
#
# Before the commit this was a multi-line function that built the graph immediately with
# `tograph(model, mx.Variable(:input))` and passed it, with an empty `Dict`, to `Model`.
# After the commit it is the one-liner `mxnet(model) = Model(model)`; no graph is built
# here. The three lines that appear only once are the removed body of the old function.
function mxnet(model) mxnet(model) = Model(model)
graph = tograph(model, mx.Variable(:input))
Model(model, graph, Dict())
end
import Base: @get! import Base: @get!
# Unchanged by the commit (both copies are identical).
# Looks up `input` in the `m.execs` cache via `@get!`; on a miss, `executor(m.graph, input)`
# is evaluated, stored under that key and returned.
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
# Call method for `Model`, shown as a side-by-side diff (left copy: before, right copy:
# after). The closing lines of the method lie outside this hunk.
#
# The two lines that appear only once are additions: if `m.graph` has not been assigned yet,
# it is built on the first call from `m.model` and `gensym("input")` names.
# NOTE(review): the added line passes `input` to `mapt`, but this method has no parameter or
# local named `input` (its argument is `x`). Unless a global `input` exists, the first call
# will fail with an UndefVarError. It presumably should read `x`; confirm and fix.
#
# After that, under `@mxerr m.graph.stacks`, `runrawbatched` is applied to `x` with a
# do-block that fetches the executor for `size(x)`, records it in `m.last`, and calls it.
function (m::Model)(x) function (m::Model)(x)
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
@mxerr m.graph.stacks runrawbatched(x) do x @mxerr m.graph.stacks runrawbatched(x) do x
m.last = exec = executor(m, size(x)) m.last = exec = executor(m, size(x))
exec(x) exec(x)
@@ -134,7 +134,7 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu()) function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label) model = rewrite_softmax(model, label)
graph = tograph(model, mx.Variable(input), feedforward=true) graph = tograph(model, input, feedforward=true)
ff = mx.FeedForward(graph.output, context = context) ff = mx.FeedForward(graph.output, context = context)
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph))) isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
return ff return ff