tweak constants approach

This commit is contained in:
Mike J Innes 2016-11-17 11:28:24 +00:00
parent 1424b75e78
commit 2aa8dfc208
1 changed files with 3 additions and 4 deletions

View File

@ -1,5 +1,6 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
using Flux: interpmap
using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules
@ -36,8 +37,6 @@ graph(op::Op, xs...) = op.f(xs...)
interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
interp(ctx, c::Constant) = node(c.value)
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ?
ctx[:params][p.value] :
@ -51,7 +50,7 @@ function interp(ctx, model, args...)
end
function tograph(model, args...)
ctx = Context(interp, params = ObjectIdDict())
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end