fix param interpretation

This commit is contained in:
Mike J Innes 2016-12-13 15:46:34 +00:00
parent a63cd826c2
commit 1b22d55401
1 changed files with 3 additions and 1 deletions

View File

@ -42,6 +42,8 @@ interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
ctx[:params][p.value] :
(ctx[:params][p.value] = Variable(p.value.x))
interp(ctx, p::Constant) = p.value
function interp(ctx, model, args...)
g = Flux.graph(model)
g == nothing && return graph(model, interpret(ctx, args)...)
@ -50,7 +52,7 @@ function interp(ctx, model, args...)
end
function tograph(model, args...)
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end