fix mxnet

This commit is contained in:
Mike J Innes 2017-05-22 18:24:14 +01:00
parent 7a2a72a74a
commit c7f8d86f9e
2 changed files with 13 additions and 16 deletions

View File

@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::Input, x) = x
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
graph(ctx::Context, d::Affine, x) =
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
register(ctx,
@ -74,19 +83,16 @@ end
register(ctx::Context, node) = node
function var(ctx::Context, p)
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
id = gensym()
ctx[:params][id] = p
return mx.Variable(id)
end
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
graph(ctx::Context, p::Constant) = p.value
var(ctx::Context, x) = x
function graph(ctx::Context, model, args...)
args = var.(ctx, args)
g = Flux.graph(model)
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
@ -96,7 +102,7 @@ end
graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, ilambda, iargs, ituple, graph),
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph),
params = Dict(), stacks = Dict(),
feedforward = feedforward)
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)

View File

@ -1,14 +1,5 @@
using Flux: collectt, shapecheckt
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id])