remove inference, for now
This commit is contained in:
parent
e79a1657d4
commit
052cc52ada
@ -25,7 +25,6 @@ include("params.jl")
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
include("compiler/interp.jl")
|
||||
include("compiler/shape.jl")
|
||||
|
||||
include("layers/control.jl")
|
||||
include("layers/affine.jl")
|
||||
|
@ -1,74 +0,0 @@
|
||||
Dims{N} = NTuple{N,Int}
|
||||
|
||||
struct Hint
|
||||
typ
|
||||
end
|
||||
|
||||
DataFlow.tocall(h::Hint, x) = :($x::$(h.typ))
|
||||
|
||||
arghint(p::Param) = arghint(state(p))
|
||||
arghint(xs::AbstractArray) = vertex(Hint(size(xs)), constant(:_))
|
||||
arghint(x) = constant(x)
|
||||
|
||||
function gethint(v::IVertex)
|
||||
while value(v) isa Union{Line,Frame} v = v[1] end
|
||||
value(v) isa Hint && return value(v).typ
|
||||
return
|
||||
end
|
||||
|
||||
ihint(f, ctx::Context, h::Hint, x) = vertex(h, x)
|
||||
ihint(f, args...) = f(args...)
|
||||
|
||||
function hintify(ctx, f, xs...)
|
||||
xs = arghint.(xs)
|
||||
sh = infer(f, map(gethint, xs)...)
|
||||
sh ≠ nothing ? vertex(Hint(sh), vertex(f, xs...)) :
|
||||
!any(x->x==nothing, xs) && graph(f) ≠ nothing ? interpret(Context(interpshape), graph(f), xs...) :
|
||||
vertex(f, xs...)
|
||||
end
|
||||
|
||||
interpshape = mux(ilinev, iconst, ihint, iargs, hintify)
|
||||
|
||||
function shapesv(f, args...)
|
||||
(g = graph(f)) == nothing && return
|
||||
ins = [vertex(Hint(d), inputnode(i)) for (i,d) in enumerate(args)]
|
||||
interpv(Context(interpshape), detuple(spliceinputs(g, ins...)))
|
||||
end
|
||||
|
||||
shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines=true))
|
||||
|
||||
# Inference primitives
|
||||
|
||||
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
||||
|
||||
infer(::typeof(tuple), xs...) = (xs...,)
|
||||
infer(s::Split, xs::Tuple) = 1 ≤ s.n ≤ length(xs) ? xs[s.n] : nothing
|
||||
infer(::typeof(identity), x) = x
|
||||
|
||||
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
||||
a[2] == b[1] || return nothing
|
||||
(a[1], b[2])
|
||||
end
|
||||
|
||||
infer(::typeof(broadcast), f, xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
|
||||
# Old broadcast versions
|
||||
infer(::typeof(.+), xs::Dims...) = Base.Broadcast.broadcast_shape(xs...)
|
||||
|
||||
# Shapes macro
|
||||
|
||||
macro shapes(ex)
|
||||
@capture(ex, f_(args__)) || error("@shapes f(args...)")
|
||||
:(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
|
||||
end
|
||||
|
||||
# Shim for kicking off shape inference
|
||||
|
||||
struct Input{N}
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
Input(i...) = Input((i...,))
|
||||
|
||||
(::Input)(x) = x
|
||||
|
||||
inferchain(f::Input, xs) = (-1, f.dims...)
|
@ -9,10 +9,3 @@ back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
|
||||
softmax(xs) = exp.(xs) ./ sum(exp.(xs), 2)
|
||||
|
||||
flatten(xs) = reshape(xs, size(xs, 1), :)
|
||||
|
||||
infer(::typeof(softmax), x) = x
|
||||
infer(::typeof(tanh), x) = x
|
||||
infer(::typeof(relu), x) = x
|
||||
infer(::typeof(σ), x) = x
|
||||
|
||||
infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)
|
||||
|
@ -7,9 +7,6 @@ end
|
||||
Affine(in::Integer, out::Integer; init = initn) =
|
||||
Affine(init(in, out), init(1, out))
|
||||
|
||||
inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
|
||||
Affine(in[1][2], out)
|
||||
|
||||
function back!(m::Affine, Δ, x)
|
||||
W, b = m.W, m.b
|
||||
W.Δx[:] = x' * Δ
|
||||
|
@ -24,24 +24,3 @@ graph(s::Chain) =
|
||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
# Chain Macros
|
||||
|
||||
inferred(f, in, args...; kws...) = f(args...; kws...)
|
||||
|
||||
# `inferchain` allows for overriding inference behaviour for convenience.
|
||||
# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
|
||||
# error, but for the interface we just ignore any errors and return (1, 20).
|
||||
inferchain(f, xs...) = infer(f, xs...)
|
||||
|
||||
macro Chain(x, xs...)
|
||||
inferconstructor(x) =
|
||||
@capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
|
||||
@q let
|
||||
shape = nothing
|
||||
c = Chain($(esc(x)))
|
||||
$([:(shape = inferchain(c.layers[end], shape);
|
||||
push!(c, $x)) for x in inferconstructor.(xs)]...)
|
||||
c
|
||||
end
|
||||
end
|
||||
|
@ -10,8 +10,6 @@ d = Affine(10, 20)
|
||||
|
||||
d1 = @net x -> x * d.W + d.b
|
||||
|
||||
@test Flux.infer(d, (1, 10)) == (1,20)
|
||||
|
||||
# Skip this before new DataFlow is released.
|
||||
# let
|
||||
# @test @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
|
||||
@ -24,7 +22,6 @@ let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||
tlp = TLP(a1, a2)
|
||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.infer(tlp, (1, 10)) == (1,15)
|
||||
end
|
||||
|
||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||
|
Loading…
Reference in New Issue
Block a user