fix mxnet
This commit is contained in:
parent
7a2a72a74a
commit
c7f8d86f9e
@ -46,6 +46,15 @@ graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
|||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
|
struct AlterParam
|
||||||
|
param
|
||||||
|
load
|
||||||
|
store
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||||
|
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
||||||
|
|
||||||
graph(ctx::Context, d::Affine, x) =
|
graph(ctx::Context, d::Affine, x) =
|
||||||
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
|
||||||
register(ctx,
|
register(ctx,
|
||||||
@ -74,19 +83,16 @@ end
|
|||||||
|
|
||||||
register(ctx::Context, node) = node
|
register(ctx::Context, node) = node
|
||||||
|
|
||||||
function var(ctx::Context, p)
|
function var(ctx::Context, p::Union{Flux.Param{<:AArray},AArray,AlterParam})
|
||||||
id = gensym()
|
id = gensym()
|
||||||
ctx[:params][id] = p
|
ctx[:params][id] = p
|
||||||
return mx.Variable(id)
|
return mx.Variable(id)
|
||||||
end
|
end
|
||||||
|
|
||||||
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
|
var(ctx::Context, x) = x
|
||||||
|
|
||||||
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
|
|
||||||
|
|
||||||
graph(ctx::Context, p::Constant) = p.value
|
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function graph(ctx::Context, model, args...)
|
||||||
|
args = var.(ctx, args)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
|
||||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
@ -96,7 +102,7 @@ end
|
|||||||
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
||||||
|
|
||||||
function tograph(model, args...; feedforward = false)
|
function tograph(model, args...; feedforward = false)
|
||||||
ctx = Context(mux(iline, ilambda, iargs, ituple, graph′),
|
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
||||||
params = Dict(), stacks = Dict(),
|
params = Dict(), stacks = Dict(),
|
||||||
feedforward = feedforward)
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||||
|
@ -1,14 +1,5 @@
|
|||||||
using Flux: collectt, shapecheckt
|
using Flux: collectt, shapecheckt
|
||||||
|
|
||||||
struct AlterParam
|
|
||||||
param
|
|
||||||
load
|
|
||||||
store
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
|
||||||
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
|
|
||||||
|
|
||||||
function copyargs!(as, bs)
|
function copyargs!(as, bs)
|
||||||
for id in intersect(keys(as), keys(bs))
|
for id in intersect(keys(as), keys(bs))
|
||||||
copy!(as[id], bs[id])
|
copy!(as[id], bs[id])
|
||||||
|
Loading…
Reference in New Issue
Block a user