replace old shape inference system

Mike J Innes 2017-03-17 16:34:51 +00:00
parent 146e3f8dc3
commit d73e962da9
10 changed files with 65 additions and 94 deletions

View File

@@ -4,7 +4,7 @@ data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
 train = data[1:50_000]
 test = data[50_001:60_000]
-m = Chain(
+m = @Chain(
   Input(784),
   Affine(128), relu,
   Affine( 64), relu,
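The `@Chain` macro introduced in this commit threads an inferred shape through the layer constructors, so each `Affine` only needs its output size. As a rough sketch (not the literal macro expansion), the model above amounts to writing the input widths out by hand:

    m = Chain(
      Input(784),
      Affine(784, 128), relu,   # 784 inferred from Input(784)
      Affine(128, 64), relu)    # 128 inferred from the previous Affine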

View File

@@ -3,17 +3,16 @@ using Flux, MXNet
 Flux.loadmx()
 conv1 = Chain(
-  Input(28,28),
   Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 conv2 = Chain(
-  conv1,
   Conv2D((5,5), in = 20, out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
-lenet = Chain(
-  conv2,
+lenet = @Chain(
+  Input(28,28,1),
+  conv1, conv2,
   flatten,
   Affine(500), tanh,
   Affine(10), softmax)

View File

@@ -1,17 +1,17 @@
 using Flux, Juno
 conv1 = Chain(
-  Reshape(28,28,1),
   Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 conv2 = Chain(
-  Input(12,12,20),
   Conv2D((5,5), in = 20, out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
-lenet = Chain(
-  conv1, conv2, flatten,
+lenet = @Chain(
+  Input(28,28,1),
+  conv1, conv2,
+  flatten,
   Affine(500), tanh,
   Affine(10), softmax)
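Assuming the `infer` rules added for `Conv2D`, the pooling layers, `flatten` and `Affine` later in this commit, the shapes should propagate through the chain above roughly as follows (the leading -1 is the batch dimension seeded by `Input`):

    # Input(28,28,1)                   -> (-1, 28, 28, 1)
    # Conv2D((5,5), out = 20)          -> (-1, 24, 24, 20)  # (28-5)÷1+1 = 24
    # MaxPool((2,2), stride = (2,2))   -> (-1, 12, 12, 20)  # (24-2)÷2+1 = 12
    # Conv2D((5,5), in = 20, out = 50) -> (-1, 8, 8, 50)
    # MaxPool((2,2), stride = (2,2))   -> (-1, 4, 4, 50)
    # flatten                          -> (-1, 800)         # 4*4*50
    # Affine(500), tanh                -> (-1, 500)
    # Affine(10), softmax              -> (-1, 10)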

View File

@@ -7,6 +7,7 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
   iscyclic, Constant, constant, isconstant, group, Split, splitnode,
   detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
   spliceinputs, bumpinputs, Line, Frame, applylines
+using MacroTools: @q
 using Juno: Tree, Row
 # Zero Flux Given
@@ -27,7 +28,6 @@ include("compiler/shape.jl")
 include("layers/affine.jl")
 include("layers/activation.jl")
 include("layers/recurrent.jl")
-include("layers/shape.jl")
 include("layers/chain.jl")
 include("layers/shims.jl")

View File

@@ -2,6 +2,8 @@ using DataFlow.Interpreter
 export @shapes
+Dims{N} = NTuple{N,Int}
 struct Hint
   typ
 end
@@ -41,7 +43,9 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
 infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
-function infer(::typeof(*), a::NTuple{2}, b::NTuple{2})
+infer(::typeof(identity), x) = x
+function infer(::typeof(*), a::Dims{2}, b::Dims{2})
   a[2] == b[1] || return nothing
   (a[1], b[2])
 end
@@ -55,3 +59,17 @@ macro shapes(ex)
   @capture(ex, f_(args__)) || error("@shapes f(args...)")
   :(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
 end
+# Shim for kicking off shape inference
+export Input
+struct Input{N} <: Model
+  dims::Dims{N}
+end
+Input(i...) = Input((i...,))
+(::Input)(x) = x
+inferchain(f::Input, xs) = (-1, f.dims...)
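A quick sketch of how these rules behave: the matrix-multiplication rule only propagates a shape when the inner dimensions agree, and `Input` is the identity at runtime but seeds inference with a -1 batch dimension:

    infer(*, (1, 784), (784, 128))        # => (1, 128)
    infer(*, (1, 784), (10, 128))         # => nothing (inner dimensions differ)
    inferchain(Input(28, 28, 1), nothing) # => (-1, 28, 28, 1)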

View File

@@ -8,11 +8,14 @@ relu(x) = max(0, x)
 back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
 # TODO: correct behaviour with batches
 softmax(xs) = exp.(xs) ./ sum(exp.(xs))
 # TODO: correct behaviour with batches
 flatten(xs) = reshape(xs, length(xs))
-shape(::typeof(flatten), in) = prod(in)
+infer(::typeof(softmax), x) = x
+infer(::typeof(tanh), x) = x
+infer(::typeof(σ), x) = x
+infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)
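For example, the new `flatten` rule keeps the leading (batch) dimension and collapses the rest, while the shape-preserving activations pass their input shape straight through:

    infer(flatten, (-1, 4, 4, 50)) # => (-1, 800)
    infer(softmax, (-1, 10))       # => (-1, 10)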

View File

@@ -8,3 +8,6 @@ end
 Affine(in::Integer, out::Integer; init = initn) =
   Affine(init(in, out), init(1, out))
+inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
+  Affine(in[1][2], out)
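Inside `@Chain`, a call such as `Affine(128)` is routed through `inferred` with the previous layer's shape, so the input width no longer has to be spelled out. Roughly:

    inferred(Affine, ((-1, 784),), 128) # builds Affine(784, 128)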

View File

@@ -1,26 +1,11 @@
-export Chain
-function inferchain(ms)
-  chain = []
-  sh = nothing
-  for m in ms
-    m = init(m, single(sh))
-    sh = shape(m, sh)
-    push!(chain, m)
-  end
-  return chain, sh
-end
+export Chain, @Chain
 type Chain <: Model
   layers::Vector{Any}
-  shape
-  function Chain(ms...)
-    ms, shape = inferchain(ms)
-    return new(ms, shape)
-  end
+  Chain(xs...) = new([xs...])
 end
-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
+@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
 @forward Chain.layers Base.start, Base.next, Base.done
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
@@ -30,6 +15,25 @@ update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
-shape(c::Chain, in) = c.shape
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
+# Chain Macros
+inferred(f, in, args...; kws...) = f(args...; kws...)
+# `inferchain` allows for overriding inference behaviour for convenience.
+# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
+# error, but for the interface we just ignore any errors and return (1, 20).
+inferchain(f, xs...) = infer(f, xs...)
+macro Chain(x, xs...)
+  inferconstructor(x) =
+    @capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
+  @q let
+    shape = nothing
+    c = Chain($(esc(x)))
+    $([:(shape = inferchain(c.layers[end], shape);
+         push!(c, $x)) for x in inferconstructor.(xs)]...)
+    c
+  end
+end
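For reference, a call like `@Chain(Input(784), Affine(128), relu)` expands to roughly the following (a sketch of the generated code, with the inferred shapes noted in comments):

    let
      shape = nothing
      c = Chain(Input(784))
      shape = inferchain(c.layers[end], shape)  # (-1, 784)
      push!(c, inferred(Affine, (shape,), 128)) # Affine(784, 128)
      shape = inferchain(c.layers[end], shape)  # (-1, 128)
      push!(c, relu)                            # bare functions are pushed as-is
      c
    end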

View File

@@ -1,47 +0,0 @@
-export Input
-Dims{N} = NTuple{N,Int}
-dims(d::Dims) = d
-dims(i...) = (i...,)
-single(i) = i
-single(i::Dims) = length(i) == 1 ? first(i) : i
-# Shim for kicking off shape inference
-struct ShapeError <: Exception
-  layer
-  shape
-end
-struct Input{N} <: Model
-  dims::Dims{N}
-end
-Input(i...) = Input(dims(i...))
-(::Input)(x) = x
-back!(::Input, Δ, x) = Δ
-# Initialise placeholder
-struct Init{F}
-  f::F
-end
-init(i::Init, input...) = i.f(input...)
-init(m, input...) = m
-# Shape inference API
-shape(x, in) = in
-shape(i::Input, _) = i.dims
-# Implementation for bundled layers
-shape(d::Affine, _) = length(state(d.b)) # TODO: could perhaps infer this
-Affine(out::Integer) = Init(in::Integer -> Affine(in, out))

View File

@@ -8,11 +8,10 @@ end
 Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
   Conv2D(param(initn(size..., in, out)), stride)
-shape(c::Conv2D, in::Dims{2}) =
-  (map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
+infer(c::Conv2D, in::Dims{4}) =
+  (in[1], map(i -> (in[i+1]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
-shape(c::Conv2D, in::Dims{3}) =
-  shape(c, (in[1],in[2]))
 # TODO: many of these should just be functions
 for Pool in :[MaxPool, AvgPool].args
   @eval begin
@@ -24,11 +23,8 @@ for Pool in :[MaxPool, AvgPool].args
     $Pool(size; stride = (1,1)) =
      $Pool(size, stride)
-    shape(c::$Pool, in::Dims{2}) =
-      map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
-    shape(c::$Pool, in::Dims{3}) =
-      (shape(c, (in[1],in[2]))..., in[3])
+    infer(c::$Pool, in::Dims{4}) =
+      (in[1], map(i -> (in[i+1]-c.size[i])÷c.stride[i]+1, (1,2))..., in[4])
-    shape(c::$Pool) = nothing
   end
@@ -40,9 +36,4 @@ end
 Reshape(dims::Integer...) = Reshape(dims)
-function shape(r::Reshape, dims)
-  prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
-  return r.dims
-end
-shape(r::Reshape, ::Void) = r.dims
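As a worked example of the 4-D rules above: a (5,5) filter with 20 output channels over 28×28×1 inputs gives (28-5)÷1+1 = 24 in each spatial dimension, and a (2,2) pool with stride (2,2) then gives (24-2)÷2+1 = 12:

    infer(Conv2D((5,5), out = 20), (-1, 28, 28, 1))         # => (-1, 24, 24, 20)
    infer(MaxPool((2,2), stride = (2,2)), (-1, 24, 24, 20)) # => (-1, 12, 12, 20)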