split interpret / interpv

This commit is contained in:
Mike J Innes 2016-12-21 13:05:18 +00:00
parent 6acfcd913e
commit 353f156354
2 changed files with 7 additions and 7 deletions

View File

@ -1,6 +1,6 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
interplambda, interpconst, interpline, stack
interpv, interplambda, interpconst, interpline, stack
using Flux: interpmap
using TensorFlow: RawTensor
@ -38,13 +38,13 @@ graph(p::MaxPool, x) =
graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...)
node = graph(model, interpret(ctx, args)...)
node = graph(model, interpv(ctx, args)...)
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
return node
end
interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ?
@ -57,7 +57,7 @@ function interp(ctx, model, args...)
g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, interpret(ctx, args)...)
interpret(ctx, g, interpv(ctx, args)...)
end
function tograph(model, args...)

View File

@ -1,4 +1,4 @@
using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Context
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
function astuple(xs::Vertex)
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
@ -17,7 +17,7 @@ end
function interpmap(cb)
function interp(ctx, ::typeof(map), f, xs...)
f, xs = interpret(ctx, (f, xs))
f, xs = interpv(ctx, (f, xs))
xs = astuples(xs)
xs nothing ?
group(map(f, xs...)...) :
@ -28,7 +28,7 @@ end
function interp(ctx, model, xs...)
g = graph(model)
g == nothing && return vertex(model, map(constant, interpret(ctx, xs))...)
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
interpret(ctx, g, xs...)
end