From d73e962da9c521712c402e3b99a1a35741dcc35e Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 17 Mar 2017 16:34:51 +0000
Subject: [PATCH] replace old shape inference system

---
 examples/MNIST.jl          |  2 +-
 examples/integration-mx.jl |  7 +++---
 examples/integration-tf.jl |  8 +++----
 src/Flux.jl                |  2 +-
 src/compiler/shape.jl      | 20 +++++++++++++++-
 src/layers/activation.jl   |  7 ++++--
 src/layers/affine.jl       |  3 +++
 src/layers/chain.jl        | 44 +++++++++++++++++++----------------
 src/layers/shape.jl        | 47 --------------------------------------
 src/layers/shims.jl        | 19 ++++-----------
 10 files changed, 65 insertions(+), 94 deletions(-)
 delete mode 100644 src/layers/shape.jl

diff --git a/examples/MNIST.jl b/examples/MNIST.jl
index 4cd104db..9ab749a3 100644
--- a/examples/MNIST.jl
+++ b/examples/MNIST.jl
@@ -4,7 +4,7 @@ data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
 train = data[1:50_000]
 test = data[50_001:60_000]
 
-m = Chain(
+m = @Chain(
   Input(784),
   Affine(128), relu,
   Affine( 64), relu,
diff --git a/examples/integration-mx.jl b/examples/integration-mx.jl
index 2874e039..cb25c39c 100644
--- a/examples/integration-mx.jl
+++ b/examples/integration-mx.jl
@@ -3,17 +3,16 @@ using Flux, MXNet
 Flux.loadmx()
 
 conv1 = Chain(
-  Input(28,28),
   Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 
 conv2 = Chain(
-  conv1,
   Conv2D((5,5), in = 20, out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
 
-lenet = Chain(
-  conv2,
+lenet = @Chain(
+  Input(28,28,1),
+  conv1, conv2,
   flatten,
   Affine(500), tanh,
   Affine(10), softmax)
diff --git a/examples/integration-tf.jl b/examples/integration-tf.jl
index f7d7b70a..12faa4d9 100644
--- a/examples/integration-tf.jl
+++ b/examples/integration-tf.jl
@@ -1,17 +1,17 @@
 using Flux, Juno
 
 conv1 = Chain(
-  Reshape(28,28,1),
   Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 
 conv2 = Chain(
-  Input(12,12,20),
   Conv2D((5,5), in = 20, out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
 
-lenet = Chain(
-  conv1, conv2, flatten,
+lenet = @Chain(
+  Input(28,28,1),
+  conv1, conv2,
+  flatten,
   Affine(500), tanh,
   Affine(10), softmax)
 
diff --git a/src/Flux.jl b/src/Flux.jl
index fb4de7d6..5972bd44 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,6 +7,7 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
   iscyclic, Constant, constant, isconstant, group, Split, splitnode,
   detuple, value, inputs, thread!, value, inputs, Split, splitnode,
   inputnode, spliceinputs, bumpinputs, Line, Frame, applylines
+using MacroTools: @q
 using Juno: Tree, Row
 
 # Zero Flux Given
@@ -27,7 +28,6 @@ include("compiler/shape.jl")
 include("layers/affine.jl")
 include("layers/activation.jl")
 include("layers/recurrent.jl")
-include("layers/shape.jl")
 include("layers/chain.jl")
 include("layers/shims.jl")
 
diff --git a/src/compiler/shape.jl b/src/compiler/shape.jl
index 3983ffa9..d4b3f284 100644
--- a/src/compiler/shape.jl
+++ b/src/compiler/shape.jl
@@ -2,6 +2,8 @@ using DataFlow.Interpreter
 
 export @shapes
 
+Dims{N} = NTuple{N,Int}
+
 struct Hint
   typ
 end
@@ -41,7 +43,9 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, lines = false))
 
 infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
 
-function infer(::typeof(*), a::NTuple{2}, b::NTuple{2})
+infer(::typeof(identity), x) = x
+
+function infer(::typeof(*), a::Dims{2}, b::Dims{2})
   a[2] == b[1] || return nothing
   (a[1], b[2])
 end
@@ -55,3 +59,17 @@ macro shapes(ex)
   @capture(ex, f_(args__)) || error("@shapes f(args...)")
   :(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
 end
+
+# Shim for kicking off shape inference
+
+export Input
+
+struct Input{N} <: Model
+  dims::Dims{N}
+end
+
+Input(i...) = Input((i...,))
+
+(::Input)(x) = x
+
+inferchain(f::Input, xs) = (-1, f.dims...)
diff --git a/src/layers/activation.jl b/src/layers/activation.jl
index 88cad04c..0ed8bb38 100644
--- a/src/layers/activation.jl
+++ b/src/layers/activation.jl
@@ -8,11 +8,14 @@ relu(x) = max(0, x)
 
 back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
 
+# TODO: correct behaviour with batches
 softmax(xs) = exp.(xs) ./ sum(exp.(xs))
 
+# TODO: correct behaviour with batches
 flatten(xs) = reshape(xs, length(xs))
 
-shape(::typeof(flatten), in) = prod(in)
-
 infer(::typeof(softmax), x) = x
+infer(::typeof(tanh), x) = x
 infer(::typeof(σ), x) = x
+
+infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)
diff --git a/src/layers/affine.jl b/src/layers/affine.jl
index 087905a8..c447b606 100644
--- a/src/layers/affine.jl
+++ b/src/layers/affine.jl
@@ -8,3 +8,6 @@ end
 
 Affine(in::Integer, out::Integer; init = initn) =
   Affine(init(in, out), init(1, out))
+
+inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
+  Affine(in[1][2], out)
diff --git a/src/layers/chain.jl b/src/layers/chain.jl
index 43d09e4d..839bb58c 100644
--- a/src/layers/chain.jl
+++ b/src/layers/chain.jl
@@ -1,26 +1,11 @@
-export Chain
-
-function inferchain(ms)
-  chain = []
-  sh = nothing
-  for m in ms
-    m = init(m, single(sh))
-    sh = shape(m, sh)
-    push!(chain, m)
-  end
-  return chain, sh
-end
+export Chain, @Chain
 
 type Chain <: Model
   layers::Vector{Any}
-  shape
-  function Chain(ms...)
-    ms, shape = inferchain(ms)
-    return new(ms, shape)
-  end
+  Chain(xs...) = new([xs...])
 end
 
-@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
+@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
 @forward Chain.layers Base.start, Base.next, Base.done
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
@@ -30,6 +15,25 @@ update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
 
 graph(s::Chain) = foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
 
-shape(c::Chain, in) = c.shape
-
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
+
+# Chain Macros
+
+inferred(f, in, args...; kws...) = f(args...; kws...)
+
+# `inferchain` allows for overriding inference behaviour for convenience.
+# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
+# error, but for the interface we just ignore any errors and return (1, 20).
+inferchain(f, xs...) = infer(f, xs...)
+
+macro Chain(x, xs...)
+  inferconstructor(x) =
+    @capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
+  @q let
+    shape = nothing
+    c = Chain($(esc(x)))
+    $([:(shape = inferchain(c.layers[end], shape);
+         push!(c, $x)) for x in inferconstructor.(xs)]...)
+    c
+  end
+end
diff --git a/src/layers/shape.jl b/src/layers/shape.jl
deleted file mode 100644
index 1c155ed6..00000000
--- a/src/layers/shape.jl
+++ /dev/null
@@ -1,47 +0,0 @@
-export Input
-
-Dims{N} = NTuple{N,Int}
-
-dims(d::Dims) = d
-
-dims(i...) = (i...,)
-
-single(i) = i
-single(i::Dims) = length(i) == 1 ? first(i) : i
-
-# Shim for kicking off shape inference
-
-struct ShapeError <: Exception
-  layer
-  shape
-end
-
-struct Input{N} <: Model
-  dims::Dims{N}
-end
-
-Input(i...) = Input(dims(i...))
-
-(::Input)(x) = x
-back!(::Input, Δ, x) = Δ
-
-# Initialise placeholder
-
-struct Init{F}
-  f::F
-end
-
-init(i::Init, input...) = i.f(input...)
-init(m, input...) = m
-
-# Shape inference API
-
-shape(x, in) = in
-
-shape(i::Input, _) = i.dims
-
-# Implementation for bundled layers
-
-shape(d::Affine, _) = length(state(d.b)) # TODO: could perhaps infer this
-
-Affine(out::Integer) = Init(in::Integer -> Affine(in, out))
diff --git a/src/layers/shims.jl b/src/layers/shims.jl
index 73e09a6e..46a31ac0 100644
--- a/src/layers/shims.jl
+++ b/src/layers/shims.jl
@@ -8,11 +8,10 @@ end
 Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
   Conv2D(param(initn(size..., in, out)), stride)
 
-shape(c::Conv2D, in::Dims{2}) =
-  (map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
+infer(c::Conv2D, in::Dims{4}) =
+  (in[1], map(i -> (in[i+1]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
 
-shape(c::Conv2D, in::Dims{3}) =
-  shape(c, (in[1],in[2]))
+# TODO: many of these should just be functions
 
 for Pool in :[MaxPool, AvgPool].args
   @eval begin
@@ -24,11 +23,8 @@ for Pool in :[MaxPool, AvgPool].args
 
     $Pool(size; stride = (1,1)) = $Pool(size, stride)
 
-    shape(c::$Pool, in::Dims{2}) =
-      map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
-
-    shape(c::$Pool, in::Dims{3}) =
-      (shape(c, (in[1],in[2]))..., in[3])
+    infer(c::$Pool, in::Dims{4}) =
+      (in[1], map(i -> (in[i+1]-c.size[i])÷c.stride[i]+1, (1,2))..., in[4])
 
     shape(c::$Pool) = nothing
   end
@@ -40,9 +36,4 @@ end
 
 Reshape(dims::Integer...) = Reshape(dims)
 
-function shape(r::Reshape, dims)
-  prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
-  return r.dims
-end
-
 shape(r::Reshape, ::Void) = r.dims
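
For reference, the new @Chain macro expands into straight-line construction
code. Below is a hand-expanded sketch of `m = @Chain(Input(784), Affine(128),
relu)` (the first few layers of the MNIST example above); the binding `m` and
the comments are illustrative, but the expansion follows the macro body in
src/layers/chain.jl:

    m = let
      shape = nothing
      c = Chain(Input(784))
      # `Input` kicks off inference: inferchain(::Input, _) returns (-1, 784),
      # where -1 stands in for the unknown batch dimension.
      shape = inferchain(c.layers[end], shape)
      # Constructor calls are rewritten through `inferred`, so Affine(128)
      # becomes Affine(784, 128) via in[1][2] (see src/layers/affine.jl).
      push!(c, inferred(Affine, (shape,), 128))
      # Plain functions like relu are not calls, so they are pushed unchanged.
      shape = inferchain(c.layers[end], shape)
      push!(c, relu)
      c
    end

The shape is threaded through before each push!, so every constructor sees the
inferred output shape of the preceding layer; for layers inference cannot
handle, `infer` simply yields `nothing`.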