# Patch for src/backend/mxnet/graph.jl (three hunks).
#
# Hunk 1 (@@ -46,6 +46,15 @@): adds the `AlterParam` struct with untyped fields
#   `param`, `load` and `store`, plus `Base.size` and `Base.copy!` methods for it.
#   Both methods operate on `p.load(p.param.x)`; `store` is not used in this hunk.
#
# Hunk 2 (@@ -74,19 +83,16 @@):
#   * narrows `var(ctx::Context, p)` to
#     `p::Union{Flux.Param{<:AArray},AArray,AlterParam}`; that method still records
#     `p` in `ctx[:params]` under a fresh `gensym()` id and returns `mx.Variable(id)`.
#   * adds the fallback `var(ctx::Context, x) = x`, which returns `x` unchanged.
#   * removes the three `graph(ctx::Context, p::Constant...)` methods.
#   * makes `graph(ctx::Context, model, args...)` rebind `args = var.(ctx, args)`
#     before calling `Flux.graph(model)`.
#
# Hunk 3 (@@ -96,7 +102,7 @@): inserts `iconst` between `iline` and `ilambda` in the
#   `mux(...)` chain passed to `Context` inside `tograph`.
#
# NOTE(review): the removed `Constant` methods are presumably superseded by `iconst`
#   together with the new `var` fallback; `iconst` is not defined in this patch, so
#   confirm it is in scope.
# NOTE(review): `var.(ctx, args)` relies on `ctx` being treated as a scalar by
#   broadcast; confirm `Context` is neither an array nor a tuple type.
# NOTE(review): the unchanged context line `g == nothing && ...` would conventionally
#   use `===`; it is hunk context, so it cannot be altered inside this patch.
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index fe9ccd07..313519a0 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...) graph(::Input, x) = x +struct AlterParam + param + load + store +end + +Base.size(p::AlterParam) = size(p.load(p.param.x)) +Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x)) + graph(ctx::Context, d::Affine, x) = !ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) : register(ctx, @@ -74,19 +83,16 @@ end register(ctx::Context, node) = node -function var(ctx::Context, p) +function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam}) id = gensym() ctx[:params][id] = p return mx.Variable(id) end -graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value) - -graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value) - -graph(ctx::Context, p::Constant) = p.value +var(ctx::Context, x) = x function graph(ctx::Context, model, args...) + args = var.(ctx, args) g = Flux.graph(model) g == nothing && return register(ctx, @icatch ctx graph(model, args...)) DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") @@ -96,7 +102,7 @@ end graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...) function tograph(model, args...; feedforward = false) - ctx = Context(mux(iline, ilambda, iargs, ituple, graph′), + ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′), params = Dict(), stacks = Dict(), feedforward = feedforward) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...) 
# Patch for src/backend/mxnet/model.jl (one hunk, @@ -1,14 +1,5 @@).
#
# Deletes the `AlterParam` struct (fields `param`, `load`, `store`) and its
# `Base.size` / `Base.copy!` methods from this file. The graph.jl portion of the
# same patch adds identical definitions, so this is a relocation, not a removal.
#
# The `using Flux: collectt, shapecheckt` line and the start of `copyargs!` are
# unchanged hunk context; `copyargs!` is cut off at the hunk boundary.
#
# NOTE(review): any remaining use of `AlterParam` in model.jl now depends on
#   graph.jl being loaded first; include order is not visible here, so confirm.
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 507c01c2..7d3c7895 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -1,14 +1,5 @@ using Flux: collectt, shapecheckt -struct AlterParam - param - load - store -end - -Base.size(p::AlterParam) = size(p.load(p.param.x)) -Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x)) - function copyargs!(as, bs) for id in intersect(keys(as), keys(bs)) copy!(as[id], bs[id])