split interpret / interpv
This commit is contained in:
parent
6acfcd913e
commit
353f156354
@ -1,6 +1,6 @@
|
||||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
||||
interplambda, interpconst, interpline, stack
|
||||
interpv, interplambda, interpconst, interpline, stack
|
||||
using Flux: interpmap
|
||||
using TensorFlow: RawTensor
|
||||
|
||||
@ -38,13 +38,13 @@ graph(p::MaxPool, x) =
|
||||
graph(op::Op, xs...) = op.f(xs...)
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
node = graph(model, interpret(ctx, args)...)
|
||||
node = graph(model, interpv(ctx, args)...)
|
||||
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||
return node
|
||||
end
|
||||
|
||||
interp(ctx, c::Conv2D, x) =
|
||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||
|
||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||
haskey(ctx[:params], p.value) ?
|
||||
@ -57,7 +57,7 @@ function interp(ctx, model, args...)
|
||||
g = Flux.graph(model)
|
||||
g == nothing && return graph(ctx, model, args...)
|
||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||
interpret(ctx, g, interpret(ctx, args)...)
|
||||
interpret(ctx, g, interpv(ctx, args)...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
|
@ -1,4 +1,4 @@
|
||||
using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Context
|
||||
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
|
||||
|
||||
function astuple(xs::Vertex)
|
||||
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
||||
@ -17,7 +17,7 @@ end
|
||||
|
||||
function interpmap(cb)
|
||||
function interp(ctx, ::typeof(map), f, xs...)
|
||||
f, xs = interpret(ctx, (f, xs))
|
||||
f, xs = interpv(ctx, (f, xs))
|
||||
xs′ = astuples(xs)
|
||||
xs′ ≠ nothing ?
|
||||
group(map(f, xs′...)...) :
|
||||
@ -28,7 +28,7 @@ end
|
||||
|
||||
function interp(ctx, model, xs...)
|
||||
g = graph(model)
|
||||
g == nothing && return vertex(model, map(constant, interpret(ctx, xs))...)
|
||||
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
|
||||
interpret(ctx, g, xs...)
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user