fix mxnet
This commit is contained in:
parent
7a2a72a74a
commit
c7f8d86f9e
@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
store
|
||||
end
|
||||
|
||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||
|
||||
graph(ctx::Context, d::Affine, x) =
|
||||
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
||||
register(ctx,
|
||||
@ -74,19 +83,16 @@ end
|
||||
|
||||
register(ctx::Context, node) = node
|
||||
|
||||
function var(ctx::Context, p)
|
||||
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
|
||||
id = gensym()
|
||||
ctx[:params][id] = p
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
|
||||
|
||||
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
|
||||
|
||||
graph(ctx::Context, p::Constant) = p.value
|
||||
var(ctx::Context, x) = x
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
args = var.(ctx, args)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
@ -96,7 +102,7 @@ end
|
||||
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
||||
|
||||
function tograph(model, args...; feedforward = false)
|
||||
ctx = Context(mux(iline, ilambda, iargs, ituple, graph′),
|
||||
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
||||
params = Dict(), stacks = Dict(),
|
||||
feedforward = feedforward)
|
||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||
|
@ -1,14 +1,5 @@
|
||||
using Flux: collectt, shapecheckt
|
||||
|
||||
struct AlterParam
|
||||
param
|
||||
load
|
||||
store
|
||||
end
|
||||
|
||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||
|
||||
function copyargs!(as, bs)
|
||||
for id in intersect(keys(as), keys(bs))
|
||||
copy!(as[id], bs[id])
|
||||
|
Loading…
Reference in New Issue
Block a user