split interpret / interpv
This commit is contained in:
parent
6acfcd913e
commit
353f156354
@ -1,6 +1,6 @@
|
|||||||
using Base: @get!
|
using Base: @get!
|
||||||
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
using DataFlow: Constant, constant, Context, interpret, Split, interptuple,
|
||||||
interplambda, interpconst, interpline, stack
|
interpv, interplambda, interpconst, interpline, stack
|
||||||
using Flux: interpmap
|
using Flux: interpmap
|
||||||
using TensorFlow: RawTensor
|
using TensorFlow: RawTensor
|
||||||
|
|
||||||
@ -38,13 +38,13 @@ graph(p::MaxPool, x) =
|
|||||||
graph(op::Op, xs...) = op.f(xs...)
|
graph(op::Op, xs...) = op.f(xs...)
|
||||||
|
|
||||||
function graph(ctx::Context, model, args...)
|
function graph(ctx::Context, model, args...)
|
||||||
node = graph(model, interpret(ctx, args)...)
|
node = graph(model, interpv(ctx, args)...)
|
||||||
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
|
isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
|
||||||
return node
|
return node
|
||||||
end
|
end
|
||||||
|
|
||||||
interp(ctx, c::Conv2D, x) =
|
interp(ctx, c::Conv2D, x) =
|
||||||
nn.conv2d(interpret(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
nn.conv2d(interpv(ctx, x), interp(ctx, Constant(c.filter)), [1,c.stride...,1], "VALID")
|
||||||
|
|
||||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||||
haskey(ctx[:params], p.value) ?
|
haskey(ctx[:params], p.value) ?
|
||||||
@ -57,7 +57,7 @@ function interp(ctx, model, args...)
|
|||||||
g = Flux.graph(model)
|
g = Flux.graph(model)
|
||||||
g == nothing && return graph(ctx, model, args...)
|
g == nothing && return graph(ctx, model, args...)
|
||||||
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
|
||||||
interpret(ctx, g, interpret(ctx, args)...)
|
interpret(ctx, g, interpv(ctx, args)...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Context
|
using DataFlow: interpret, interpv, interptuple, interplambda, interpconst, Context
|
||||||
|
|
||||||
function astuple(xs::Vertex)
|
function astuple(xs::Vertex)
|
||||||
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
||||||
@ -17,7 +17,7 @@ end
|
|||||||
|
|
||||||
function interpmap(cb)
|
function interpmap(cb)
|
||||||
function interp(ctx, ::typeof(map), f, xs...)
|
function interp(ctx, ::typeof(map), f, xs...)
|
||||||
f, xs = interpret(ctx, (f, xs))
|
f, xs = interpv(ctx, (f, xs))
|
||||||
xs′ = astuples(xs)
|
xs′ = astuples(xs)
|
||||||
xs′ ≠ nothing ?
|
xs′ ≠ nothing ?
|
||||||
group(map(f, xs′...)...) :
|
group(map(f, xs′...)...) :
|
||||||
@ -28,7 +28,7 @@ end
|
|||||||
|
|
||||||
function interp(ctx, model, xs...)
|
function interp(ctx, model, xs...)
|
||||||
g = graph(model)
|
g = graph(model)
|
||||||
g == nothing && return vertex(model, map(constant, interpret(ctx, xs))...)
|
g == nothing && return vertex(model, map(constant, interpv(ctx, xs))...)
|
||||||
interpret(ctx, g, xs...)
|
interpret(ctx, g, xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user