shape debugger

This commit is contained in:
Mike J Innes 2016-12-26 18:55:43 +00:00
parent 87e928638a
commit 8d1171cb96
2 changed files with 48 additions and 0 deletions

View File

@ -16,6 +16,7 @@ include("data.jl")
include("compiler/code.jl")
include("compiler/loops.jl")
include("compiler/interp.jl")
include("compiler/shape.jl")
include("layers/affine.jl")
include("layers/activation.jl")

47
src/compiler/shape.jl Normal file
View File

@ -0,0 +1,47 @@
using DataFlow: iline, iargs
type Hint
typ
end
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
function gethint(v::IVertex)
isa(value(v), Hint) && return value(v).typ
return
end
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
ihint(f, args...) = f(args...)
hintify(c::Constant) = hintify(state(c.value))
hintify(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
interpshape = mux(iline, ihint, iargs, ituple, hintify)
function hintify(f, xs...)
sh = infer(f, map(gethint, xs)...)
sh nothing ? vertex(Hint(sh), vertex(f, xs...)) :
!any(x->x==nothing, xs) && graph(f) nothing ? interpret(Context(interpshape), graph(f), xs...) :
vertex(f, xs...)
end
function shapes(f, args...)
(g = graph(f)) == nothing && return
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
interpret(Context(interpshape), g, ins...)
end
# Inference primitives
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapes(f, args...))
function infer(::typeof(*), a::NTuple{2}, b::NTuple{2})
a[2] == b[1] || return nothing
(a[1], b[2])
end
# TODO: make correct
infer(::typeof(+), a, b) = a
infer(::typeof(softmax), x) = x
infer(::typeof(σ), x) = x