replace old shape inference system
This commit is contained in:
parent
146e3f8dc3
commit
d73e962da9
@ -4,7 +4,7 @@ data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
test = data[50_001:60_000]
|
||||
|
||||
m = Chain(
|
||||
m = @Chain(
|
||||
Input(784),
|
||||
Affine(128), relu,
|
||||
Affine( 64), relu,
|
||||
|
@ -3,17 +3,16 @@ using Flux, MXNet
|
||||
Flux.loadmx()
|
||||
|
||||
conv1 = Chain(
|
||||
Input(28,28),
|
||||
Conv2D((5,5), out = 20), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
conv2 = Chain(
|
||||
conv1,
|
||||
Conv2D((5,5), in = 20, out = 50), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
lenet = Chain(
|
||||
conv2,
|
||||
lenet = @Chain(
|
||||
Input(28,28,1),
|
||||
conv1, conv2,
|
||||
flatten,
|
||||
Affine(500), tanh,
|
||||
Affine(10), softmax)
|
||||
|
@ -1,17 +1,17 @@
|
||||
using Flux, Juno
|
||||
|
||||
conv1 = Chain(
|
||||
Reshape(28,28,1),
|
||||
Conv2D((5,5), out = 20), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
conv2 = Chain(
|
||||
Input(12,12,20),
|
||||
Conv2D((5,5), in = 20, out = 50), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
lenet = Chain(
|
||||
conv1, conv2, flatten,
|
||||
lenet = @Chain(
|
||||
Input(28,28,1),
|
||||
conv1, conv2,
|
||||
flatten,
|
||||
Affine(500), tanh,
|
||||
Affine(10), softmax)
|
||||
|
||||
|
@ -7,6 +7,7 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||
spliceinputs, bumpinputs, Line, Frame, applylines
|
||||
using MacroTools: @q
|
||||
using Juno: Tree, Row
|
||||
|
||||
# Zero Flux Given
|
||||
@ -27,7 +28,6 @@ include("compiler/shape.jl")
|
||||
include("layers/affine.jl")
|
||||
include("layers/activation.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/shape.jl")
|
||||
include("layers/chain.jl")
|
||||
include("layers/shims.jl")
|
||||
|
||||
|
@ -2,6 +2,8 @@ using DataFlow.Interpreter
|
||||
|
||||
export @shapes
|
||||
|
||||
Dims{N} = NTuple{N,Int}
|
||||
|
||||
struct Hint
|
||||
typ
|
||||
end
|
||||
@ -41,7 +43,9 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
|
||||
|
||||
# Generic inference entry point: if `graph(f)` yields `nothing` there is nothing to
# trace, so the result is `nothing`; otherwise the hint is taken from `shapesv(f, args...)`.
# `===` is an identity test against the `nothing` singleton, so it cannot be
# affected by `==` overloads on graph types.
infer(f, args...) = graph(f) === nothing ? nothing : gethint(shapesv(f, args...))
|
||||
|
||||
function infer(::typeof(*), a::NTuple{2}, b::NTuple{2})
|
||||
# `identity` hands its argument back unchanged, so the inferred hint is the input hint.
infer(::typeof(identity), x) = x
|
||||
|
||||
# Shape rule for multiplying two 2-d shapes: `(m, k) * (k, n)` gives `(m, n)`.
# Returns `nothing` when the inner dimensions disagree.
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
    rows, inner_a = a
    inner_b, cols = b
    return inner_a == inner_b ? (rows, cols) : nothing
end
|
||||
@ -55,3 +59,17 @@ macro shapes(ex)
|
||||
@capture(ex, f_(args__)) || error("@shapes f(args...)")
|
||||
:(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
|
||||
end
|
||||
|
||||
# Shim for kicking off shape inference
|
||||
|
||||
export Input
|
||||
|
||||
struct Input{N} <: Model
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
# Varargs convenience constructor: packs the given arguments into a tuple and
# forwards to the constructor taking `dims`, e.g. `Input(28, 28)` -> `Input((28, 28))`.
Input(i...) = Input((i...,))
|
||||
|
||||
(::Input)(x) = x
|
||||
|
||||
# Shape reported for an `Input` layer: its declared `dims` prefixed with -1.
# The incoming shape `xs` is ignored.
# NOTE(review): the leading -1 presumably stands for an unknown batch size; confirm.
inferchain(f::Input, xs) = (-1, f.dims...)
|
||||
|
@ -8,11 +8,14 @@ relu(x) = max(0, x)
|
||||
|
||||
# Backward pass for `relu`: propagate the incoming gradient `Δ` only where the
# input was positive (where `relu` has slope 1) and zero it elsewhere.
# The mask must be `x .> 0`; `x .< 0` would pass gradient exactly where relu is flat.
back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
|
||||
|
||||
# TODO: correct behaviour with batches
|
||||
# Softmax over all elements of `xs`: exponentiate and normalise so the result sums to 1.
# The maximum is subtracted before exponentiating; this leaves the result
# mathematically unchanged but avoids overflow (Inf / Inf = NaN) for large inputs.
# The exponentials are computed once and reused for the normalising sum.
function softmax(xs)
    # Nothing to normalise; `maximum` would throw on an empty collection.
    isempty(xs) && return exp.(xs)
    es = exp.(xs .- maximum(xs))
    return es ./ sum(es)
end
|
||||
|
||||
# TODO: correct behaviour with batches
|
||||
# Collapse `xs` into a one-dimensional array holding all of its elements.
function flatten(xs)
    n = length(xs)
    return reshape(xs, n)
end
|
||||
|
||||
shape(::typeof(flatten), in) = prod(in)
|
||||
|
||||
infer(::typeof(softmax), x) = x
|
||||
infer(::typeof(tanh), x) = x
|
||||
infer(::typeof(σ), x) = x
|
||||
|
||||
# Shape rule for `flatten`: the first dimension is kept and all remaining
# dimensions are multiplied together into a single second dimension.
function infer(::typeof(flatten), x::Dims)
    lead = x[1]
    rest = prod(x[2:end])
    return (lead, rest)
end
|
||||
|
@ -8,3 +8,6 @@ end
|
||||
|
||||
Affine(in::Integer, out::Integer; init = initn) =
|
||||
Affine(init(in, out), init(1, out))
|
||||
|
||||
# Shape-inferring constructor for `Affine`: given a single 2-d input shape, the
# layer's input size is read from that shape's second entry, so only `out` is needed.
function inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer)
    inshape = in[1]
    return Affine(inshape[2], out)
end
|
||||
|
@ -1,26 +1,11 @@
|
||||
export Chain
|
||||
|
||||
function inferchain(ms)
|
||||
chain = []
|
||||
sh = nothing
|
||||
for m in ms
|
||||
m = init(m, single(sh))
|
||||
sh = shape(m, sh)
|
||||
push!(chain, m)
|
||||
end
|
||||
return chain, sh
|
||||
end
|
||||
export Chain, @Chain
|
||||
|
||||
type Chain <: Model
|
||||
layers::Vector{Any}
|
||||
shape
|
||||
function Chain(ms...)
|
||||
ms, shape = inferchain(ms)
|
||||
return new(ms, shape)
|
||||
end
|
||||
Chain(xs...) = new([xs...])
|
||||
end
|
||||
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
|
||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||
@forward Chain.layers Base.start, Base.next, Base.done
|
||||
|
||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||
@ -30,6 +15,25 @@ update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
||||
graph(s::Chain) =
|
||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||
|
||||
shape(c::Chain, in) = c.shape
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
# Chain Macros
|
||||
|
||||
# Fallback for constructors without a shape-aware method: the input shape `in`
# is ignored and `f` is called with the remaining positional and keyword arguments.
function inferred(f, in, args...; kws...)
    return f(args...; kws...)
end
|
||||
|
||||
# `inferchain` allows for overriding inference behaviour for convenience.
|
||||
# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
|
||||
# error, but for the interface we just ignore any errors and return (1, 20).
|
||||
inferchain(f, xs...) = infer(f, xs...)
|
||||
|
||||
macro Chain(x, xs...)
|
||||
inferconstructor(x) =
|
||||
@capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
|
||||
@q let
|
||||
shape = nothing
|
||||
c = Chain($(esc(x)))
|
||||
$([:(shape = inferchain(c.layers[end], shape);
|
||||
push!(c, $x)) for x in inferconstructor.(xs)]...)
|
||||
c
|
||||
end
|
||||
end
|
||||
|
@ -1,47 +0,0 @@
|
||||
export Input
|
||||
|
||||
Dims{N} = NTuple{N,Int}
|
||||
|
||||
dims(d::Dims) = d
|
||||
|
||||
dims(i...) = (i...,)
|
||||
|
||||
single(i) = i
|
||||
single(i::Dims) = length(i) == 1 ? first(i) : i
|
||||
|
||||
# Shim for kicking off shape inference
|
||||
|
||||
struct ShapeError <: Exception
|
||||
layer
|
||||
shape
|
||||
end
|
||||
|
||||
struct Input{N} <: Model
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
||||
Input(i...) = Input(dims(i...))
|
||||
|
||||
(::Input)(x) = x
|
||||
back!(::Input, Δ, x) = Δ
|
||||
|
||||
# Initialise placeholder
|
||||
|
||||
struct Init{F}
|
||||
f::F
|
||||
end
|
||||
|
||||
init(i::Init, input...) = i.f(input...)
|
||||
init(m, input...) = m
|
||||
|
||||
# Shape inference API
|
||||
|
||||
shape(x, in) = in
|
||||
|
||||
shape(i::Input, _) = i.dims
|
||||
|
||||
# Implementation for bundled layers
|
||||
|
||||
shape(d::Affine, _) = length(state(d.b)) # TODO: could perhaps infer this
|
||||
|
||||
Affine(out::Integer) = Init(in::Integer -> Affine(in, out))
|
@ -8,11 +8,10 @@ end
|
||||
# Keyword constructor: creates a filter of dimensions `(size..., in, out)` using
# `init`, wraps it with `param`, and stores it together with `stride`.
# The filter is built with the caller-supplied `init`; previously `initn` was
# hard-coded, so the `init` keyword was silently ignored. The default is still
# `initn`, so calls that do not pass `init` behave as before.
Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
    Conv2D(param(init(size..., in, out)), stride)
|
||||
|
||||
shape(c::Conv2D, in::Dims{2}) =
|
||||
(map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
|
||||
infer(c::Conv2D, in::Dims{4}) =
|
||||
(in[1], map(i -> (in[i+1]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
|
||||
|
||||
shape(c::Conv2D, in::Dims{3}) =
|
||||
shape(c, (in[1],in[2]))
|
||||
# TODO: many of these should just be functions
|
||||
|
||||
for Pool in :[MaxPool, AvgPool].args
|
||||
@eval begin
|
||||
@ -24,11 +23,8 @@ for Pool in :[MaxPool, AvgPool].args
|
||||
$Pool(size; stride = (1,1)) =
|
||||
$Pool(size, stride)
|
||||
|
||||
shape(c::$Pool, in::Dims{2}) =
|
||||
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
|
||||
|
||||
shape(c::$Pool, in::Dims{3}) =
|
||||
(shape(c, (in[1],in[2]))..., in[3])
|
||||
infer(c::$Pool, in::Dims{4}) =
|
||||
(in[1], map(i -> (in[i+1]-c.size[i])÷c.stride[i]+1, (1,2))..., in[4])
|
||||
|
||||
shape(c::$Pool) = nothing
|
||||
end
|
||||
@ -40,9 +36,4 @@ end
|
||||
|
||||
Reshape(dims::Integer...) = Reshape(dims)
|
||||
|
||||
function shape(r::Reshape, dims)
|
||||
prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
|
||||
return r.dims
|
||||
end
|
||||
|
||||
shape(r::Reshape, ::Void) = r.dims
|
||||
|
Loading…
Reference in New Issue
Block a user