fix mxnet

This commit is contained in:
Mike J Innes 2017-05-22 18:24:14 +01:00
parent 7a2a72a74a
commit c7f8d86f9e
2 changed files with 13 additions and 16 deletions

View File

@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::Input, x) = x graph(::Input, x) = x
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
graph(ctx::Context, d::Affine, x) = graph(ctx::Context, d::Affine, x) =
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) : !ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
register(ctx, register(ctx,
@ -74,19 +83,16 @@ end
register(ctx::Context, node) = node register(ctx::Context, node) = node
function var(ctx::Context, p) function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
id = gensym() id = gensym()
ctx[:params][id] = p ctx[:params][id] = p
return mx.Variable(id) return mx.Variable(id)
end end
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value) var(ctx::Context, x) = x
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
graph(ctx::Context, p::Constant) = p.value
function graph(ctx::Context, model, args...) function graph(ctx::Context, model, args...)
args = var.(ctx, args)
g = Flux.graph(model) g = Flux.graph(model)
g == nothing && return register(ctx, @icatch ctx graph(model, args...)) g == nothing && return register(ctx, @icatch ctx graph(model, args...))
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
@ -96,7 +102,7 @@ end
graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...) graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...; feedforward = false) function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, ilambda, iargs, ituple, graph), ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph),
params = Dict(), stacks = Dict(), params = Dict(), stacks = Dict(),
feedforward = feedforward) feedforward = feedforward)
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)

View File

@ -1,14 +1,5 @@
using Flux: collectt, shapecheckt using Flux: collectt, shapecheckt
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
function copyargs!(as, bs) function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs)) for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id]) copy!(as[id], bs[id])