basic convnet example working

Mike J Innes 2016-09-06 18:03:39 +01:00
parent 205e1215d6
commit 8961b4c10f
8 changed files with 145 additions and 5 deletions

examples/mnist-conv.jl (new file, +47)

@@ -0,0 +1,47 @@
using Flux, MXNet
# Flux aims to provide high-level APIs that work well across backends, but in
# some cases you may want to take advantage of features specific to a given
# backend (or alternatively, Flux may simply not have an implementation of that
# feature yet). In these cases it's easy to "drop down" and use the backend's
# API directly, where appropriate.
# In this example, both things are happening; firstly, Flux doesn't yet support
# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
# a simple "shim" type, `Conv`. This is provided by the library but could easily
# have been user-defined.
# Secondly, we want to take advantage of MXNet.jl's training process and
# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
# regular MXNet model, and the rest of the process is trivial.
conv1 = Chain(
  Input(28,28),
  Conv((5,5),20), tanh,
  MaxPool((2,2), stride = (2,2)))

conv2 = Chain(
  conv1,
  Conv((5,5),50), tanh,
  MaxPool((2,2), stride = (2,2)))

lenet = Chain(
  conv2,
  flatten,
  Dense(500), tanh,
  Dense(10), softmax)
#--------------------------------------------------------------------------------
# Now we can continue exactly as in plain MXNet, following
# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
batch_size = 100
include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
model = mx.FeedForward(lenet, context = mx.gpu())
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)
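# Follow-up sketch, mirroring the accuracy check in the linked lenet.jl example:
# `mx.predict` runs the trained net over the eval set and returns a 10×N matrix of
# class probabilities, one column per image.
probs = mx.predict(model, eval_provider)
preds = mapslices(indmax, probs, 1) .- 1  # predicted digits 0–9, one per eval image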

src/Flux.jl

@@ -15,6 +15,7 @@ include("compiler/loops.jl")
include("layers/dense.jl")
include("layers/shape.jl")
include("layers/chain.jl")
include("layers/shims.jl")
include("cost.jl")
include("activation.jl")

src/activation.jl

@@ -1,4 +1,4 @@
export σ, relu, softmax, flatten
σ(x) = 1 ./ (1 .+ exp.(-x))
@@ -9,3 +9,7 @@ relu(x) = max(0, x)
back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
softmax(xs) = exp.(xs) ./ sum(exp.(xs))
flatten(xs) = reshape(xs, length(xs))
shape(::typeof(flatten), in) = prod(in)
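# A quick illustration of how the two new definitions fit together: `flatten` turns an
# array into a vector, and the matching `shape` rule tells shape inference that an
# (m,n,…)-sized input becomes a prod(m,n,…)-element vector.
flatten(ones(12,12,20))      # 2880-element vector
shape(flatten, (12,12,20))   # == 2880, matching the flattened length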

@@ -25,13 +25,49 @@ function graph(vars, model::Model, args...)
  end |> value
end
type SoftmaxOutput
  name::Symbol
end
function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  (g == nothing || value(g) ≠ softmax || Flow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
# Built-in implementations
node(::typeof(*), args...) = mx.dot(args...)
node(::typeof(+), args...) = mx.broadcast_plus(args...)
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
node(::typeof(tanh), x) = mx.Activation(data = x, act_type=:tanh)
node(::typeof(flatten), x) = mx.Flatten(data = x)
node(::typeof(softmax), xs) =
  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
node(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
graph(vars, ::Input, x) = x
graph(vars, ::Input, x, args...) = error("too many arguments to Input")
graph(vars, c::Conv, x) =
  mx.Convolution(data = x,
                 kernel = c.size,
                 num_filter = c.features,
                 stride = c.stride)
graph(vars, p::MaxPool, x) =
  mx.Pooling(data = x,
             pool_type = :max,
             kernel = p.size,
             stride = p.stride)
# TODO: fix the initialisation issue
graph(vars, d::Dense, x) =
  mx.FullyConnected(data = x,
                    num_hidden = size(d.W.x, 1),
                    # weight = graph(vars, d.W),
                    # bias = graph(vars, d.b)
                    )

@@ -47,9 +47,14 @@ function load!(model::MXModel)
  return model
end
function mxgraph(model, input)
  vars = Dict{Symbol,Any}()
  node = graph(vars, model, mx.Variable(input))
  return node, vars
end

function mxnet(model::Model, input)
  node, vars = mxgraph(model, :input)
  args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
  grads = mxgrads(args)
  model = MXModel(model, vars, grads,
@@ -81,3 +86,11 @@ function Flux.update!(model::MXModel, η)
  end
  return model
end
# MX FeedForward interface
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
  model = rewrite_softmax(model, label)
  node, _ = mxgraph(model, input)
  return mx.FeedForward(node, context = context)
end
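# A minimal sketch of the new interface, assuming the same shape-inference-driven
# `Dense` sizing used in examples/mnist-conv.jl: any model ending in `softmax` can be
# passed straight to `mx.FeedForward`, and `rewrite_softmax` swaps that trailing
# `softmax` for an `mx.SoftmaxOutput` node named after the `label` keyword.
mlp = Chain(Input(28,28), flatten, Dense(128), relu, Dense(10), softmax)
ff = mx.FeedForward(mlp, context = mx.cpu())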

src/layers/chain.jl

@@ -34,3 +34,5 @@
graph(s::Chain) =
  foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
shape(c::Chain, in) = c.shape

src/layers/shape.jl

@@ -11,6 +11,11 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
# Shim for kicking off shape inference
type ShapeError <: Exception
  layer
  shape
end

type Input{N} <: Model
  dims::Dims{N}
end

src/layers/shims.jl (new file, +32)

@@ -0,0 +1,32 @@
export Conv, MaxPool

type Conv <: Model
  size::Dims{2}
  features::Int
  stride::Dims{2}
end

Conv(size, features; stride = (1,1)) =
  Conv(size, features, stride)

shape(c::Conv, in::Dims{2}) =
  (map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))..., c.features)

shape(c::Conv, in::Dims{3}) =
  shape(c, (in[1],in[2]))

type MaxPool <: Model
  size::Dims{2}
  stride::Dims{2}
end

MaxPool(size; stride = (1,1)) =
  MaxPool(size, stride)

shape(c::MaxPool, in::Dims{2}) =
  map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))

shape(c::MaxPool, in::Dims{3}) =
  (shape(c, (in[1],in[2]))..., in[3])

shape(c::MaxPool, in) = throw(ShapeError(c, in))
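# Worked example of the shape arithmetic above, for the 28×28 MNIST layers used in
# examples/mnist-conv.jl (5×5 valid convolutions with stride 1, 2×2 pooling with stride 2):
shape(Conv((5,5),20), (28,28))                     # (24,24,20)  since (28-5)÷1+1 == 24
shape(MaxPool((2,2), stride = (2,2)), (24,24,20))  # (12,12,20)  since (24-2)÷2+1 == 12
shape(Conv((5,5),50), (12,12,20))                  # (8,8,50)
shape(MaxPool((2,2), stride = (2,2)), (8,8,50))    # (4,4,50), i.e. 800 features after `flatten`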