build graphs lazily
This commit is contained in:
parent
acbc4ea071
commit
5df56b6073
@ -10,7 +10,7 @@ using Base: @get!
|
|||||||
using DataFlow: Constant, constant
|
using DataFlow: Constant, constant
|
||||||
using DataFlow.Interpreter
|
using DataFlow.Interpreter
|
||||||
using DataFlow.Interpreter: Exception, totrace
|
using DataFlow.Interpreter: Exception, totrace
|
||||||
using Flux: imap
|
using Flux: imap, mapt
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ function tograph(model, args...; feedforward = false)
|
|||||||
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
|
ctx = Context(mux(iline, ilambda, imap, iargs, ituple, graph′),
|
||||||
params = Dict(), stacks = Dict(),
|
params = Dict(), stacks = Dict(),
|
||||||
feedforward = feedforward)
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, args...)
|
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||||
return Graph(out, ctx[:params], ctx[:stacks])
|
return Graph(args, out, ctx[:params], ctx[:stacks])
|
||||||
end
|
end
|
||||||
|
|
||||||
# Error Handling
|
# Error Handling
|
||||||
|
@ -16,6 +16,7 @@ function copyargs!(as, bs)
|
|||||||
end
|
end
|
||||||
|
|
||||||
struct Graph
|
struct Graph
|
||||||
|
input
|
||||||
output
|
output
|
||||||
params::Dict{Symbol,Any}
|
params::Dict{Symbol,Any}
|
||||||
stacks::Dict{Any,Any}
|
stacks::Dict{Any,Any}
|
||||||
@ -48,8 +49,8 @@ mxungroup(x, outs) = copy(shift!(outs))
|
|||||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||||
|
|
||||||
function executor(graph::Graph, input)
|
function executor(graph::Graph, input)
|
||||||
args = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
args = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||||
grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
grads = merge(mxparams(graph), Dict(graph.input[1] => MXArray(input)))
|
||||||
exec = mx.bind(mxgroup(graph.output),
|
exec = mx.bind(mxgroup(graph.output),
|
||||||
args = ndparams(args),
|
args = ndparams(args),
|
||||||
args_grad = ndparams(grads),
|
args_grad = ndparams(grads),
|
||||||
@ -60,15 +61,15 @@ function executor(graph::Graph, input)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function (exec::Exec)(input)
|
function (exec::Exec)(input)
|
||||||
copy!(exec.args[:input], input)
|
copy!(exec.args[exec.graph.input[1]], input)
|
||||||
mx.forward(exec.exec, is_train = true)
|
mx.forward(exec.exec, is_train = true)
|
||||||
mxungroup(exec.graph.output, copy(exec.outs))
|
mxungroup(exec.graph.output, copy(exec.outs))
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(exec::Exec, Δ)
|
function Flux.back!(exec::Exec, Δ)
|
||||||
exec.grads[:input][:] = 0
|
exec.grads[exec.graph.input[1]][:] = 0
|
||||||
mx.backward(exec.exec, MXArray(Δ).data)
|
mx.backward(exec.exec, MXArray(Δ).data)
|
||||||
copy(exec.grads[:input])
|
copy(exec.grads[exec.graph.input[1]])
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.update!(exec::Exec, η)
|
function Flux.update!(exec::Exec, η)
|
||||||
@ -86,22 +87,21 @@ end
|
|||||||
|
|
||||||
mutable struct Model <: Flux.Model
|
mutable struct Model <: Flux.Model
|
||||||
model::Any
|
model::Any
|
||||||
graph::Graph
|
|
||||||
execs::Dict{Tuple,Exec}
|
execs::Dict{Tuple,Exec}
|
||||||
|
graph::Graph
|
||||||
last::Exec
|
last::Exec
|
||||||
Model(model, graph, execs) = new(model, graph, execs)
|
Model(model) = new(model, Dict())
|
||||||
end
|
end
|
||||||
|
|
||||||
function mxnet(model)
|
mxnet(model) = Model(model)
|
||||||
graph = tograph(model, mx.Variable(:input))
|
|
||||||
Model(model, graph, Dict())
|
|
||||||
end
|
|
||||||
|
|
||||||
import Base: @get!
|
import Base: @get!
|
||||||
|
|
||||||
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
|
||||||
|
|
||||||
function (m::Model)(x)
|
function (m::Model)(x)
|
||||||
|
!isdefined(m, :graph) &&
|
||||||
|
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
|
||||||
@mxerr m.graph.stacks runrawbatched(x) do x
|
@mxerr m.graph.stacks runrawbatched(x) do x
|
||||||
m.last = exec = executor(m, size(x))
|
m.last = exec = executor(m, size(x))
|
||||||
exec(x)
|
exec(x)
|
||||||
@ -134,7 +134,7 @@ end
|
|||||||
|
|
||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
graph = tograph(model, input, feedforward=true)
|
||||||
ff = mx.FeedForward(graph.output, context = context)
|
ff = mx.FeedForward(graph.output, context = context)
|
||||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
|
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
|
||||||
return ff
|
return ff
|
||||||
|
Loading…
Reference in New Issue
Block a user