fix param interpretation
This commit is contained in:
parent
a63cd826c2
commit
1b22d55401
|
@ -42,6 +42,8 @@ interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
|||
ctx[:params][p.value] :
|
||||
(ctx[:params][p.value] = Variable(p.value.x))
|
||||
|
||||
interp(ctx, p::Constant) = p.value
|
||||
|
||||
function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(model, interpret(ctx, args)...)
|
||||
|
@ -50,7 +52,7 @@ function interp(ctx, model, args...)
|
|||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
|
||||
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue