fix basic interpreters
This commit is contained in:
parent
3532c7174f
commit
f7eb5179b1
@ -12,7 +12,7 @@ end
|
|||||||
|
|
||||||
function makegraph(graph, args, params = [])
|
function makegraph(graph, args, params = [])
|
||||||
graph = prewalk(graph) do v
|
graph = prewalk(graph) do v
|
||||||
value(v) isa Constant && (i = findfirst(args, value(v).value)) ≠ 0 ?
|
isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ?
|
||||||
inputnode(i) :
|
inputnode(i) :
|
||||||
v
|
v
|
||||||
end
|
end
|
||||||
@ -42,7 +42,7 @@ end
|
|||||||
|
|
||||||
function deref_params(v)
|
function deref_params(v)
|
||||||
map(v) do x
|
map(v) do x
|
||||||
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
@capture(x, self.p_) ? :(Flux.state(self.$p)) : x
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
function astuple(xs::Vertex)
|
function astuple(xs::Vertex)
|
||||||
isconstant(xs) && value(xs).value isa Tuple ? value(xs).value :
|
isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
|
||||||
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
xs isa Vertex && value(xs) == tuple ? inputs(xs) :
|
||||||
nothing
|
nothing
|
||||||
end
|
end
|
||||||
@ -21,10 +21,7 @@ function interp(ctx, f, xs...)
|
|||||||
f(xs...))
|
f(xs...))
|
||||||
end
|
end
|
||||||
|
|
||||||
interp(ctx::Context, c::Constant{<:Param}) = c.value.x
|
|
||||||
interp(ctx::Context, c::Constant) = c.value
|
|
||||||
|
|
||||||
function interpmodel(m, args...)
|
function interpmodel(m, args...)
|
||||||
ctx = Context(mux(iline, ilambda, iargs, ituple, interp))
|
ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
|
||||||
@ithrow interp(ctx, m, args...)
|
@ithrow interp(ctx, m, args...)
|
||||||
end
|
end
|
||||||
|
@ -8,6 +8,10 @@ end
|
|||||||
|
|
||||||
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
||||||
|
|
||||||
|
arghint(p::Param) = arghint(state(p))
|
||||||
|
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||||
|
arghint(x) = constant(x)
|
||||||
|
|
||||||
function gethint(v::IVertex)
|
function gethint(v::IVertex)
|
||||||
while value(v) isa Union{Line,Frame} v = v[1] end
|
while value(v) isa Union{Line,Frame} v = v[1] end
|
||||||
value(v) isa Hint && return value(v).typ
|
value(v) isa Hint && return value(v).typ
|
||||||
@ -17,19 +21,16 @@ end
|
|||||||
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
||||||
ihint(f, args...) = f(args...)
|
ihint(f, args...) = f(args...)
|
||||||
|
|
||||||
hintify(ctx, c::Constant{<:Union{Param,AbstractArray}}) = hintify(ctx, state(c.value))
|
|
||||||
hintify(ctx, xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
|
||||||
hintify(ctx, c::Constant) = vertex(c)
|
|
||||||
|
|
||||||
interpshape = mux(ilinev, ihint, iargs, hintify)
|
|
||||||
|
|
||||||
function hintify(ctx, f, xs...)
|
function hintify(ctx, f, xs...)
|
||||||
sh = infer(f, gethint.(xs)...)
|
xs = arghint.(xs)
|
||||||
|
sh = infer(f, map(gethint, xs)...)
|
||||||
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
||||||
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
||||||
vertex(f, xs...)
|
vertex(f, xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
|
||||||
|
|
||||||
function shapesv(f, args...)
|
function shapesv(f, args...)
|
||||||
(g = graph(f)) == nothing && return
|
(g = graph(f)) == nothing && return
|
||||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||||
|
Loading…
Reference in New Issue
Block a user