From f7eb5179b1995b6d0bc782c6eda60782689d6ab6 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 22 May 2017 17:39:08 +0100 Subject: [PATCH] fix basic interpreters --- src/compiler/code.jl | 4 ++-- src/compiler/interp.jl | 7 ++----- src/compiler/shape.jl | 15 ++++++++------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index d3a1f688..f9192a90 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -12,7 +12,7 @@ end function makegraph(graph, args, params = []) graph = prewalk(graph) do v - value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ? + isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ? inputnode(i) : v end @@ -42,7 +42,7 @@ end function deref_params(v) map(v) do x - x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x + @capture(x, self.p_) ? :(Flux.state(self.$p)) : x end end diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index 2e307006..6b9a022d 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -1,5 +1,5 @@ function astuple(xs::Vertex) - isconstant(xs) && value(xs).value isa Tuple ? value(xs).value : + isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) : xs isa Vertex && value(xs) == tuple ? inputs(xs) : nothing end @@ -21,10 +21,7 @@ function interp(ctx, f, xs...) f(xs...)) end -interp(ctx::Context, c::Constant{<:Param}) = c.value.x -interp(ctx::Context, c::Constant) = c.value - function interpmodel(m, args...) - ctx = Context(mux(iline, ilambda, iargs, ituple, interp)) + ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp)) @ithrow interp(ctx, m, args...) end diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl index a944cc62..7de80e61 100644 --- a/src/compiler/shape.jl +++ b/src/compiler/shape.jl @@ -8,6 +8,10 @@ end DataFlow.tocall(h::Hint, x) = :($x::$(h.typ)) +arghint(p::Param) = arghint(state(p)) +arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_)) +arghint(x) = constant(x) + function gethint(v::IVertex) while value(v) isa Union{Line,Frame} v = v[1] end value(v) isa Hint && return value(v).typ @@ -17,19 +21,16 @@ end ihint(f, ctx::Context, h::Hint, x) = vertex(h, x) ihint(f, args...) = f(args...) -hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value)) -hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_)) -hintify(ctx, c::Constant) = vertex(c) - -interpshape = mux(ilinev, ihint, iargs, hintify) - function hintify(ctx, f, xs...) - sh = infer(f, gethint.(xs)...) + xs = arghint.(xs) + sh = infer(f, map(gethint, xs)...) sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) : !any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) : vertex(f, xs...) end +interpshape = mux(ilinev, iconst, ihint, iargs, hintify) + function shapesv(f, args...) (g = graph(f)) == nothing && return ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]