split interpret / interpv

This commit is contained in:
Mike J Innes 2016-12-21 13:05:18 +00:00
parent 6acfcd913e
commit 353f156354
2 changed files with 7 additions and 7 deletions

View File

@ -1,6 +1,6 @@
using Base: @get! using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split, interptuple, using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
interplambda, interpconst, interpline, stack interpv, interplambda, interpconst, interpline, stack
using Flux: interpmap using Flux: interpmap
using TensorFlow: RawTensor using TensorFlow: RawTensor
@ -38,13 +38,13 @@ graph(p::MaxPool, x) =
graph(op::Op, xs...) = op.f(xs...) graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...) function graph(ctx::Context, model, args...)
node = graph(model, interpret(ctx, args)...) node = graph(model, interpv(ctx, args)...)
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx)) isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
return node return node
end end
interp(ctx, c::Conv2D, x) = interp(ctx, c::Conv2D, x) =
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID") nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ? haskey(ctx[:params], p.value) ?
@ -57,7 +57,7 @@ function interp(ctx, model, args...)
g = Flux.graph(model) g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...) g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.") DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, interpret(ctx, args)...) interpret(ctx, g, interpv(ctx, args)...)
end end
function tograph(model, args...) function tograph(model, args...)

View File

@ -1,4 +1,4 @@
using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Context using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
function astuple(xs::Vertex) function astuple(xs::Vertex)
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value : isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
@ -17,7 +17,7 @@ end
function interpmap(cb) function interpmap(cb)
function interp(ctx, ::typeof(map), f, xs...) function interp(ctx, ::typeof(map), f, xs...)
f, xs = interpret(ctx, (f, xs)) f, xs = interpv(ctx, (f, xs))
xs = astuples(xs) xs = astuples(xs)
xs nothing ? xs nothing ?
group(map(f, xs...)...) : group(map(f, xs...)...) :
@ -28,7 +28,7 @@ end
function interp(ctx, model, xs...) function interp(ctx, model, xs...)
g = graph(model) g = graph(model)
g == nothing && return vertex(model, map(constant, interpret(ctx, xs))...) g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
interpret(ctx, g, xs...) interpret(ctx, g, xs...)
end end