build graphs lazily
This commit is contained in:
parent
acbc4ea071
commit
5df56b6073
@ -10,7 +10,7 @@ using Base: @get!
|
||||
using DataFlow: Constant, constant
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: Exception, totrace
|
||||
using Flux: imap
|
||||
using Flux: imap, mapt
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
@ -84,8 +84,8 @@ function tograph(model, args...; feedforward = false)
|
||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
|
||||
params = Dict(), stacks = Dict(),
|
||||
feedforward = feedforward)
|
||||
out = @ithrow graph(ctx, model, args...)
|
||||
return Graph(out, ctx[:params], ctx[:stacks])
|
||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||
return Graph(args, out, ctx[:params], ctx[:stacks])
|
||||
end
|
||||
|
||||
# Error Handling
|
||||
|
@ -16,6 +16,7 @@ function copyargs!(as, bs)
|
||||
end
|
||||
|
||||
struct Graph
|
||||
input
|
||||
output
|
||||
params::Dict{Symbol,Any}
|
||||
stacks::Dict{Any,Any}
|
||||
@ -48,8 +49,8 @@ mxungroup(x, outs) = copy(shift!(outs))
|
||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||
|
||||
function executor(graph::Graph, input)
|
||||
args = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
||||
grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
||||
args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||
grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||
exec = mx.bind(mxgroup(graph.output),
|
||||
args = ndparams(args),
|
||||
args_grad = ndparams(grads),
|
||||
@ -60,15 +61,15 @@ function executor(graph::Graph, input)
|
||||
end
|
||||
|
||||
function (exec::Exec)(input)
|
||||
copy!(exec.args[:input], input)
|
||||
copy!(exec.args[exec.graph.input[1]], input)
|
||||
mx.forward(exec.exec, is_train = true)
|
||||
mxungroup(exec.graph.output, copy(exec.outs))
|
||||
end
|
||||
|
||||
function Flux.back!(exec::Exec, Δ)
|
||||
exec.grads[:input][:] = 0
|
||||
exec.grads[exec.graph.input[1]][:] = 0
|
||||
mx.backward(exec.exec, MXArray(Δ).data)
|
||||
copy(exec.grads[:input])
|
||||
copy(exec.grads[exec.graph.input[1]])
|
||||
end
|
||||
|
||||
function Flux.update!(exec::Exec, η)
|
||||
@ -86,22 +87,21 @@ end
|
||||
|
||||
mutable struct Model <: Flux.Model
|
||||
model::Any
|
||||
graph::Graph
|
||||
execs::Dict{Tuple,Exec}
|
||||
graph::Graph
|
||||
last::Exec
|
||||
Model(model, graph, execs) = new(model, graph, execs)
|
||||
Model(model) = new(model, Dict())
|
||||
end
|
||||
|
||||
function mxnet(model)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
Model(model, graph, Dict())
|
||||
end
|
||||
mxnet(model) = Model(model)
|
||||
|
||||
import Base: @get!
|
||||
|
||||
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
||||
|
||||
function (m::Model)(x)
|
||||
!isdefined(m, :graph) &&
|
||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
|
||||
@mxerr m.graph.stacks runrawbatched(x) do x
|
||||
m.last = exec = executor(m, size(x))
|
||||
exec(x)
|
||||
@ -134,7 +134,7 @@ end
|
||||
|
||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||
graph = tograph(model, input, feedforward=true)
|
||||
ff = mx.FeedForward(graph.output, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
|
||||
return ff
|
||||
|
Loading…
Reference in New Issue
Block a user