fix param interpretation
This commit is contained in:
parent
a63cd826c2
commit
1b22d55401
@ -42,6 +42,8 @@ interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
|||||||
ctx[:params][p.value] :
|
ctx[:params][p.value] :
|
||||||
(ctx[:params][p.value] = Variable(p.value.x))
|
(ctx[:params][p.value] = Variable(p.value.x))
|
||||||
|
|
||||||
|
interp(ctx, p::Constant) = p.value
|
||||||
|
|
||||||
function interp(ctx, model, args...)
|
function interp(ctx, model, args...)
|
||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(model, interpret(ctx, args)...)
|
g == nothing && return graph(model, interpret(ctx, args)...)
|
||||||
@ -50,7 +52,7 @@ function interp(ctx, model, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
|
ctx = Context(interplambda(interptuple(interpmap(interp))), params = ObjectIdDict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], out
|
return ctx[:params], out
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user