replace old shape inference system
This commit is contained in:
parent
146e3f8dc3
commit
d73e962da9
@ -4,7 +4,7 @@ data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
|||||||
train = data[1:50_000]
|
train = data[1:50_000]
|
||||||
test = data[50_001:60_000]
|
test = data[50_001:60_000]
|
||||||
|
|
||||||
m = Chain(
|
m = @Chain(
|
||||||
Input(784),
|
Input(784),
|
||||||
Affine(128), relu,
|
Affine(128), relu,
|
||||||
Affine( 64), relu,
|
Affine( 64), relu,
|
||||||
|
@ -3,17 +3,16 @@ using Flux, MXNet
|
|||||||
Flux.loadmx()
|
Flux.loadmx()
|
||||||
|
|
||||||
conv1 = Chain(
|
conv1 = Chain(
|
||||||
Input(28,28),
|
|
||||||
Conv2D((5,5), out = 20), tanh,
|
Conv2D((5,5), out = 20), tanh,
|
||||||
MaxPool((2,2), stride = (2,2)))
|
MaxPool((2,2), stride = (2,2)))
|
||||||
|
|
||||||
conv2 = Chain(
|
conv2 = Chain(
|
||||||
conv1,
|
|
||||||
Conv2D((5,5), in = 20, out = 50), tanh,
|
Conv2D((5,5), in = 20, out = 50), tanh,
|
||||||
MaxPool((2,2), stride = (2,2)))
|
MaxPool((2,2), stride = (2,2)))
|
||||||
|
|
||||||
lenet = Chain(
|
lenet = @Chain(
|
||||||
conv2,
|
Input(28,28,1),
|
||||||
|
conv1, conv2,
|
||||||
flatten,
|
flatten,
|
||||||
Affine(500), tanh,
|
Affine(500), tanh,
|
||||||
Affine(10), softmax)
|
Affine(10), softmax)
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
using Flux, Juno
|
using Flux, Juno
|
||||||
|
|
||||||
conv1 = Chain(
|
conv1 = Chain(
|
||||||
Reshape(28,28,1),
|
|
||||||
Conv2D((5,5), out = 20), tanh,
|
Conv2D((5,5), out = 20), tanh,
|
||||||
MaxPool((2,2), stride = (2,2)))
|
MaxPool((2,2), stride = (2,2)))
|
||||||
|
|
||||||
conv2 = Chain(
|
conv2 = Chain(
|
||||||
Input(12,12,20),
|
|
||||||
Conv2D((5,5), in = 20, out = 50), tanh,
|
Conv2D((5,5), in = 20, out = 50), tanh,
|
||||||
MaxPool((2,2), stride = (2,2)))
|
MaxPool((2,2), stride = (2,2)))
|
||||||
|
|
||||||
lenet = Chain(
|
lenet = @Chain(
|
||||||
conv1, conv2, flatten,
|
Input(28,28,1),
|
||||||
|
conv1, conv2,
|
||||||
|
flatten,
|
||||||
Affine(500), tanh,
|
Affine(500), tanh,
|
||||||
Affine(10), softmax)
|
Affine(10), softmax)
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
|||||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||||
spliceinputs, bumpinputs, Line, Frame, applylines
|
spliceinputs, bumpinputs, Line, Frame, applylines
|
||||||
|
using MacroTools: @q
|
||||||
using Juno: Tree, Row
|
using Juno: Tree, Row
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
@ -27,7 +28,6 @@ include("compiler/shape.jl")
|
|||||||
include("layers/affine.jl")
|
include("layers/affine.jl")
|
||||||
include("layers/activation.jl")
|
include("layers/activation.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/shape.jl")
|
|
||||||
include("layers/chain.jl")
|
include("layers/chain.jl")
|
||||||
include("layers/shims.jl")
|
include("layers/shims.jl")
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@ using DataFlow.Interpreter
|
|||||||
|
|
||||||
export @shapes
|
export @shapes
|
||||||
|
|
||||||
|
Dims{N} = NTuple{N,Int}
|
||||||
|
|
||||||
struct Hint
|
struct Hint
|
||||||
typ
|
typ
|
||||||
end
|
end
|
||||||
@ -41,7 +43,9 @@ shapes(args...) = shapesv(args...) |> syntax |> applylines |> (x->prettify(x, li
|
|||||||
|
|
||||||
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
infer(f, args...) = graph(f) == nothing ? nothing : gethint(shapesv(f, args...))
|
||||||
|
|
||||||
function infer(::typeof(*), a::NTuple{2}, b::NTuple{2})
|
infer(::typeof(identity), x) = x
|
||||||
|
|
||||||
|
function infer(::typeof(*), a::Dims{2}, b::Dims{2})
|
||||||
a[2] == b[1] || return nothing
|
a[2] == b[1] || return nothing
|
||||||
(a[1], b[2])
|
(a[1], b[2])
|
||||||
end
|
end
|
||||||
@ -55,3 +59,17 @@ macro shapes(ex)
|
|||||||
@capture(ex, f_(args__)) || error("@shapes f(args...)")
|
@capture(ex, f_(args__)) || error("@shapes f(args...)")
|
||||||
:(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
|
:(shapes($(esc(f)), mapt(size, ($(map(esc, args)...),))...))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Shim for kicking off shape inference
|
||||||
|
|
||||||
|
export Input
|
||||||
|
|
||||||
|
struct Input{N} <: Model
|
||||||
|
dims::Dims{N}
|
||||||
|
end
|
||||||
|
|
||||||
|
Input(i...) = Input((i...,))
|
||||||
|
|
||||||
|
(::Input)(x) = x
|
||||||
|
|
||||||
|
inferchain(f::Input, xs) = (-1, f.dims...)
|
||||||
|
@ -8,11 +8,14 @@ relu(x) = max(0, x)
|
|||||||
|
|
||||||
back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
|
back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
|
||||||
|
|
||||||
|
# TODO: correct behaviour with batches
|
||||||
softmax(xs) = exp.(xs) ./ sum(exp.(xs))
|
softmax(xs) = exp.(xs) ./ sum(exp.(xs))
|
||||||
|
|
||||||
|
# TODO: correct behaviour with batches
|
||||||
flatten(xs) = reshape(xs, length(xs))
|
flatten(xs) = reshape(xs, length(xs))
|
||||||
|
|
||||||
shape(::typeof(flatten), in) = prod(in)
|
|
||||||
|
|
||||||
infer(::typeof(softmax), x) = x
|
infer(::typeof(softmax), x) = x
|
||||||
|
infer(::typeof(tanh), x) = x
|
||||||
infer(::typeof(σ), x) = x
|
infer(::typeof(σ), x) = x
|
||||||
|
|
||||||
|
infer(::typeof(flatten), x::Dims) = (x[1], prod(x[2:end])...)
|
||||||
|
@ -8,3 +8,6 @@ end
|
|||||||
|
|
||||||
Affine(in::Integer, out::Integer; init = initn) =
|
Affine(in::Integer, out::Integer; init = initn) =
|
||||||
Affine(init(in, out), init(1, out))
|
Affine(init(in, out), init(1, out))
|
||||||
|
|
||||||
|
inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
|
||||||
|
Affine(in[1][2], out)
|
||||||
|
@ -1,26 +1,11 @@
|
|||||||
export Chain
|
export Chain, @Chain
|
||||||
|
|
||||||
function inferchain(ms)
|
|
||||||
chain = []
|
|
||||||
sh = nothing
|
|
||||||
for m in ms
|
|
||||||
m = init(m, single(sh))
|
|
||||||
sh = shape(m, sh)
|
|
||||||
push!(chain, m)
|
|
||||||
end
|
|
||||||
return chain, sh
|
|
||||||
end
|
|
||||||
|
|
||||||
type Chain <: Model
|
type Chain <: Model
|
||||||
layers::Vector{Any}
|
layers::Vector{Any}
|
||||||
shape
|
Chain(xs...) = new([xs...])
|
||||||
function Chain(ms...)
|
|
||||||
ms, shape = inferchain(ms)
|
|
||||||
return new(ms, shape)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof
|
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
|
||||||
@forward Chain.layers Base.start, Base.next, Base.done
|
@forward Chain.layers Base.start, Base.next, Base.done
|
||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
@ -30,6 +15,25 @@ update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
|||||||
graph(s::Chain) =
|
graph(s::Chain) =
|
||||||
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
|
||||||
|
|
||||||
shape(c::Chain, in) = c.shape
|
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
|
# Chain Macros
|
||||||
|
|
||||||
|
inferred(f, in, args...; kws...) = f(args...; kws...)
|
||||||
|
|
||||||
|
# `inferchain` allows for overriding inference behaviour for convenience.
|
||||||
|
# For example, `infer(Affine(10, 20), nothing)` would normally return a shape
|
||||||
|
# error, but for the interface we just ignore any errors and return (1, 20).
|
||||||
|
inferchain(f, xs...) = infer(f, xs...)
|
||||||
|
|
||||||
|
macro Chain(x, xs...)
|
||||||
|
inferconstructor(x) =
|
||||||
|
@capture(x, f_(xs__)) ? :(inferred($(esc(f)), (shape,), $(esc.(xs)...))) : esc(x)
|
||||||
|
@q let
|
||||||
|
shape = nothing
|
||||||
|
c = Chain($(esc(x)))
|
||||||
|
$([:(shape = inferchain(c.layers[end], shape);
|
||||||
|
push!(c, $x)) for x in inferconstructor.(xs)]...)
|
||||||
|
c
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -1,47 +0,0 @@
|
|||||||
export Input
|
|
||||||
|
|
||||||
Dims{N} = NTuple{N,Int}
|
|
||||||
|
|
||||||
dims(d::Dims) = d
|
|
||||||
|
|
||||||
dims(i...) = (i...,)
|
|
||||||
|
|
||||||
single(i) = i
|
|
||||||
single(i::Dims) = length(i) == 1 ? first(i) : i
|
|
||||||
|
|
||||||
# Shim for kicking off shape inference
|
|
||||||
|
|
||||||
struct ShapeError <: Exception
|
|
||||||
layer
|
|
||||||
shape
|
|
||||||
end
|
|
||||||
|
|
||||||
struct Input{N} <: Model
|
|
||||||
dims::Dims{N}
|
|
||||||
end
|
|
||||||
|
|
||||||
Input(i...) = Input(dims(i...))
|
|
||||||
|
|
||||||
(::Input)(x) = x
|
|
||||||
back!(::Input, Δ, x) = Δ
|
|
||||||
|
|
||||||
# Initialise placeholder
|
|
||||||
|
|
||||||
struct Init{F}
|
|
||||||
f::F
|
|
||||||
end
|
|
||||||
|
|
||||||
init(i::Init, input...) = i.f(input...)
|
|
||||||
init(m, input...) = m
|
|
||||||
|
|
||||||
# Shape inference API
|
|
||||||
|
|
||||||
shape(x, in) = in
|
|
||||||
|
|
||||||
shape(i::Input, _) = i.dims
|
|
||||||
|
|
||||||
# Implementation for bundled layers
|
|
||||||
|
|
||||||
shape(d::Affine, _) = length(state(d.b)) # TODO: could perhaps infer this
|
|
||||||
|
|
||||||
Affine(out::Integer) = Init(in::Integer -> Affine(in, out))
|
|
@ -8,11 +8,10 @@ end
|
|||||||
Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
|
Conv2D(size; in = 1, out = 1, stride = (1,1), init = initn) =
|
||||||
Conv2D(param(initn(size..., in, out)), stride)
|
Conv2D(param(initn(size..., in, out)), stride)
|
||||||
|
|
||||||
shape(c::Conv2D, in::Dims{2}) =
|
infer(c::Conv2D, in::Dims{4}) =
|
||||||
(map(i -> (in[i]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
|
(in[1], map(i -> (in[i+1]-size(c.filter,i))÷c.stride[i]+1, (1,2))..., size(c.filter, 4))
|
||||||
|
|
||||||
shape(c::Conv2D, in::Dims{3}) =
|
# TODO: many of these should just be functions
|
||||||
shape(c, (in[1],in[2]))
|
|
||||||
|
|
||||||
for Pool in :[MaxPool, AvgPool].args
|
for Pool in :[MaxPool, AvgPool].args
|
||||||
@eval begin
|
@eval begin
|
||||||
@ -24,11 +23,8 @@ for Pool in :[MaxPool, AvgPool].args
|
|||||||
$Pool(size; stride = (1,1)) =
|
$Pool(size; stride = (1,1)) =
|
||||||
$Pool(size, stride)
|
$Pool(size, stride)
|
||||||
|
|
||||||
shape(c::$Pool, in::Dims{2}) =
|
infer(c::$Pool, in::Dims{4}) =
|
||||||
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
|
(in[1], map(i -> (in[i+1]-c.size[i])÷c.stride[i]+1, (1,2))..., in[4])
|
||||||
|
|
||||||
shape(c::$Pool, in::Dims{3}) =
|
|
||||||
(shape(c, (in[1],in[2]))..., in[3])
|
|
||||||
|
|
||||||
shape(c::$Pool) = nothing
|
shape(c::$Pool) = nothing
|
||||||
end
|
end
|
||||||
@ -40,9 +36,4 @@ end
|
|||||||
|
|
||||||
Reshape(dims::Integer...) = Reshape(dims)
|
Reshape(dims::Integer...) = Reshape(dims)
|
||||||
|
|
||||||
function shape(r::Reshape, dims)
|
|
||||||
prod(dims) == prod(r.dims) || throw(ShapeError(r, dims))
|
|
||||||
return r.dims
|
|
||||||
end
|
|
||||||
|
|
||||||
shape(r::Reshape, ::Void) = r.dims
|
shape(r::Reshape, ::Void) = r.dims
|
||||||
|
Loading…
Reference in New Issue
Block a user