fix basic interpreters

This commit is contained in:
Mike J Innes 2017-05-22 17:39:08 +01:00
parent 3532c7174f
commit f7eb5179b1
3 changed files with 12 additions and 14 deletions

View File

@ -12,7 +12,7 @@ end
function makegraph(graph, args, params = [])
graph = prewalk(graph) do v
value(v) isa Constant && (i = findfirst(args, value(v).value)) 0 ?
isconstant(v) && (i = findfirst(args, value(v[1]))) 0 ?
inputnode(i) :
v
end
@ -42,7 +42,7 @@ end
function deref_params(v)
map(v) do x
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
end
end

View File

@ -1,5 +1,5 @@
function astuple(xs::Vertex)
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
nothing
end
@ -21,10 +21,7 @@ function interp(ctx, f, xs...)
f(xs...))
end
interp(ctx::Context, c::Constant{<:Param}) = c.value.x
interp(ctx::Context, c::Constant) = c.value
function interpmodel(m, args...)
ctx = Context(mux(iline, ilambda, iargs, ituple, interp))
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
@ithrow interp(ctx, m, args...)
end

View File

@ -8,6 +8,10 @@ end
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
arghint(p::Param) = arghint(state(p))
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
arghint(x) = constant(x)
function gethint(v::IVertex)
while value(v) isa Union{Line,Frame} v = v[1] end
value(v) isa Hint && return value(v).typ
@ -17,19 +21,16 @@ end
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
ihint(f, args...) = f(args...)
hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value))
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
hintify(ctx, c::Constant) = vertex(c)
interpshape = mux(ilinev, ihint, iargs, hintify)
function hintify(ctx, f, xs...)
sh = infer(f, gethint.(xs)...)
xs = arghint.(xs)
sh = infer(f, map(gethint, xs)...)
sh nothing ? vertex(Hint(sh), vertex(f, xs...)) :
!any(x->x==nothing, xs) && graph(f) nothing ? interpret(Context(interpshape), graph(f), xs...) :
vertex(f, xs...)
end
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
function shapesv(f, args...)
(g = graph(f)) == nothing && return
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]