tweak constants approach
This commit is contained in:
parent
1424b75e78
commit
2aa8dfc208
|
@ -1,5 +1,6 @@
|
|||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, Split
|
||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
|
||||
using Flux: interpmap
|
||||
using TensorFlow: RawTensor
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
@ -36,8 +37,6 @@ graph(op::Op, xs...) = op.f(xs...)
|
|||
interp(ctx, c::Conv2D, x) =
|
||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
|
||||
interp(ctx, c::Constant) = node(c.value)
|
||||
|
||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||
haskey(ctx[:params], p.value) ?
|
||||
ctx[:params][p.value] :
|
||||
|
@ -51,7 +50,7 @@ function interp(ctx, model, args...)
|
|||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(interp, params = ObjectIdDict())
|
||||
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue