tweak constants approach
This commit is contained in:
parent
1424b75e78
commit
2aa8dfc208
@ -1,5 +1,6 @@
|
|||||||
using Base: @get!
|
using Base: @get!
|
||||||
using DataFlow: Constant, constant, Context, interpret, Split
|
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, interplambda, interpconst
|
||||||
|
using Flux: interpmap
|
||||||
using TensorFlow: RawTensor
|
using TensorFlow: RawTensor
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
@ -36,8 +37,6 @@ graph(op::Op, xs...) = op.f(xs...)
|
|||||||
interp(ctx, c::Conv2D, x) =
|
interp(ctx, c::Conv2D, x) =
|
||||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||||
|
|
||||||
interp(ctx, c::Constant) = node(c.value)
|
|
||||||
|
|
||||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||||
haskey(ctx[:params], p.value) ?
|
haskey(ctx[:params], p.value) ?
|
||||||
ctx[:params][p.value] :
|
ctx[:params][p.value] :
|
||||||
@ -51,7 +50,7 @@ function interp(ctx, model, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
ctx = Context(interp, params = ObjectIdDict())
|
ctx = Context(interplambda(interptuple(interpmap(interpconst(interp)))), params = ObjectIdDict())
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], out
|
return ctx[:params], out
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user