diff --git a/examples/MNIST.jl b/examples/MNIST.jl
index d6f705b7..ce29e58a 100644
--- a/examples/MNIST.jl
+++ b/examples/MNIST.jl
@@ -1,6 +1,4 @@
-using Flux, MNIST, Flow, MacroTools
-import Flux.MX: mxnet
-import Flux: back!, update!, graph
+using Flux, MNIST

 @time begin
   const data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
@@ -11,9 +9,9 @@ end

 m = Chain(
   Input(784),
-  Dense(784, 128), relu,
-  Dense(128, 64), relu,
-  Dense(64, 10), softmax)
+  Dense(128), relu,
+  Dense( 64), relu,
+  Dense( 10), softmax)

 model = mxnet(m, 784)
diff --git a/src/Flux.jl b/src/Flux.jl
index 8a0e6c67..f4121a35 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -10,12 +10,13 @@ include("utils.jl")
 include("compiler/diff.jl")
 include("compiler/code.jl")

+include("layers/dense.jl")
+include("layers/shape.jl")
+include("layers/chain.jl")
+
 include("cost.jl")
 include("activation.jl")

-include("layers/input.jl")
-include("layers/dense.jl")
-include("layers/sequence.jl")
-include("backend/mxnet/mxnet.jl")
+include("backend/backend.jl")

 end # module
diff --git a/src/activation.jl b/src/activation.jl
index 37eddb8f..52c62ca3 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -8,4 +8,4 @@ relu(x) = max(0, x)

 back(::typeof(relu), Δ, x) = Δ .* (x .< 0)

-softmax(x) = error("not implemented")
+softmax(xs) = exp.(xs) ./ sum(exp.(xs))
diff --git a/src/backend/backend.jl b/src/backend/backend.jl
new file mode 100644
index 00000000..0715fd8b
--- /dev/null
+++ b/src/backend/backend.jl
@@ -0,0 +1,7 @@
+# TODO: load backends lazily
+
+include("mxnet/mxnet.jl")
+
+using .MX
+
+export mxnet
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index 57b3fc09..b3728621 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -32,3 +32,6 @@ node(::typeof(+), args...) = mx.broadcast_plus(args...)
 node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
 node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
 node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+
+graph(vars, ::Input, x) = x
+graph(vars, ::Input, x, args...) = error("too many arguments to Input")
diff --git a/src/backend/mxnet/mxnet.jl b/src/backend/mxnet/mxnet.jl
index ae532f85..147a1aa7 100644
--- a/src/backend/mxnet/mxnet.jl
+++ b/src/backend/mxnet/mxnet.jl
@@ -2,6 +2,8 @@ module MX

 using MXNet, Flow, ..Flux

+export mxnet
+
 include("graph.jl")
 include("model.jl")
diff --git a/src/layers/chain.jl b/src/layers/chain.jl
new file mode 100644
index 00000000..906e19c8
--- /dev/null
+++ b/src/layers/chain.jl
@@ -0,0 +1,30 @@
+export Chain
+
+function inferchain(ms)
+  chain = []
+  sh = nothing
+  for m in ms
+    m = init(m, single(sh))
+    sh = shape(m, sh)
+    push!(chain, m)
+  end
+  return chain, sh
+end
+
+type Chain <: Model
+  layers::Vector{Any}
+  shape
+  function Chain(ms...)
+    ms, shape = inferchain(ms)
+    return new(ms, shape)
+  end
+end
+
+@forward Chain.layers Base.getindex, Base.first, Base.last
+
+(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
+back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
+update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
+
+graph(s::Chain) =
+  foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
diff --git a/src/layers/dense.jl b/src/layers/dense.jl
index 1577ca2d..1a3fa302 100644
--- a/src/layers/dense.jl
+++ b/src/layers/dense.jl
@@ -1,5 +1,7 @@
 export Dense

+# TODO: type hints for parameters
+
 @model type Dense
   W
   b
@@ -9,7 +11,8 @@ end
 Dense(in::Integer, out::Integer; init = initn) =
   Dense(init(out, in), init(out))

-Base.show(io::IO, ::Dense) = print(io, "Flux.Dense(...)")
+Base.show(io::IO, d::Dense) =
+  print(io, "Flux.Dense($(size(d.W.x,2)),$(size(d.W.x,1)))")

 @model type Sigmoid
   layer::Model
@@ -18,3 +21,13 @@ end

 Sigmoid(in::Integer, out::Integer; init = randn) =
   Sigmoid(Dense(in, out, init = init))
+
+# @model type Recurrent
+#   Wxh; Whh; Bh
+#   Wxy; Why; By
+#
+#   function (x)
+#     hidden = σ( Wxh*x + Whh*hidden + Bh )
+#     y = σ( Wxy*x + Why*hidden + By )
+#   end
+# end
diff --git a/src/layers/input.jl b/src/layers/input.jl
deleted file mode 100644
index 8b060ed5..00000000
--- a/src/layers/input.jl
+++ /dev/null
@@ -1,26 +0,0 @@
-export Input
-
-typealias Dims{N} NTuple{N,Int}
-
-dims(d::Dims) = d
-
-dims(i...) = (i...,)
-
-type Input{N} <: Model
-  dims::Dims{N}
-end
-
-Input(i...) = Input(dims(i...))
-
-(::Input)(x) = x
-back!(::Input, ∇, x) = ∇
-
-shape(i::Input) = i.dims
-
-# Initialise placeholder
-
-type Init{F}
-  f::F
-end
-
-(f::Init)(args...) = f.f(args...)
diff --git a/src/layers/sequence.jl b/src/layers/sequence.jl
deleted file mode 100644
index 8297af77..00000000
--- a/src/layers/sequence.jl
+++ /dev/null
@@ -1,23 +0,0 @@
-export Sequence
-
-type Sequence <: Model
-  layers::Vector{Model}
-end
-
-Sequence() = Sequence([])
-
-@forward Sequence.layers Base.getindex, Base.first, Base.last
-
-Base.push!(s::Sequence, m::Model) = push!(s.layers, m)
-
-Base.push!(s::Sequence, f::Init) = push!(s, f(shape(last(s))))
-
-function Sequence(ms...)
-  s = Sequence()
-  foreach(m -> push!(s, m), ms)
-  return s
-end
-
-(s::Sequence)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Sequence, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
-update!(s::Sequence, η) = foreach(l -> update!(l, η), s.layers)
diff --git a/src/layers/shape.jl b/src/layers/shape.jl
new file mode 100644
index 00000000..59afcc76
--- /dev/null
+++ b/src/layers/shape.jl
@@ -0,0 +1,42 @@
+export Input
+
+# Shim for kicking off shape inference
+
+typealias Dims{N} NTuple{N,Int}
+
+dims(d::Dims) = d
+
+dims(i...) = (i...,)
+
+single(i) = i
+single(i::Dims) = length(i) == 1 ? first(i) : i
+
+type Input{N} <: Model
+  dims::Dims{N}
+end
+
+Input(i...) = Input(dims(i...))
+
+(::Input)(x) = x
+back!(::Input, ∇, x) = ∇
+
+# Initialise placeholder
+
+type Init{F}
+  f::F
+end
+
+init(i::Init, input...) = i.f(input...)
+init(m, input...) = m
+
+# Shape inference API
+
+shape(x, in) = in
+
+shape(i::Input, _) = i.dims
+
+# Implementation for bundled layers
+
+shape(d::Dense, _) = length(state(d.b)) # TODO: could perhaps infer this
+
+Dense(out::Integer) = Init(in::Integer -> Dense(in, out))
diff --git a/src/utils.jl b/src/utils.jl
index ce031586..effa5e72 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -5,19 +5,20 @@ const AArray = AbstractArray
 onehot(label, labels) = [i == label for i in labels]
 onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]

-function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
 initn(dims...) = randn(dims...)/100

+function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
   i = 0
   ∇ = zeros(length(train[1][2]))
   for _ in 1:epoch
-    for (x, y) in shuffle!(train)
+    for (x, y) in train
       i += 1
-      err = mse!(∇, m(x), y)
+      pred = m(x)
+      any(isnan, pred) && error("NaN")
+      err = mse!(∇, pred, y)
       back!(m, ∇, x)
-      i % batch == 0 && update!(m, η/batch)
+      i % batch == 0 && (update!(m, η); @show accuracy(m, test))
     end
-    @show accuracy(m, test)
   end
   return m
 end
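
Note (illustration only, not part of the patch): a rough sketch of how the shape-inference shim added in src/layers/shape.jl resolves a size-only Dense into a concrete layer, assuming the patched Flux is loaded; the concrete sizes are made up for the example.

    d = Dense(128)         # Init placeholder: output width known, input width not yet
    layer = init(d, 784)   # calls the stored closure, giving Dense(784, 128)
    shape(layer, nothing)  # length(state(layer.b)) == 128, passed on to the next layer's init

Chain's inferchain applies this init/shape handshake layer by layer, which is why the updated examples/MNIST.jl can write Dense(128), Dense( 64), Dense( 10) without repeating the input sizes.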