diff --git a/examples/mnist-conv.jl b/examples/mnist-conv.jl
new file mode 100644
index 00000000..124acf66
--- /dev/null
+++ b/examples/mnist-conv.jl
@@ -0,0 +1,47 @@
+using Flux, MXNet
+
+# Flux aims to provide high-level APIs that work well across backends, but in
+# some cases you may want to take advantage of features specific to a given
+# backend (or alternatively, Flux may simply not have an implementation of that
+# feature yet). In these cases it's easy to "drop down" and use the backend's
+# API directly, where appropriate.
+
+# In this example, both things are happening; firstly, Flux doesn't yet support
+# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
+# a simple "shim" type, `Conv`. This is provided by the library but could easily
+# have been user-defined.
+
+# Secondly, we want to take advantage of MXNet.jl's training process and
+# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
+# regular MXNet model, and the rest of the process is trivial.
+
+conv1 = Chain(
+  Input(28,28),
+  Conv((5,5),20), tanh,
+  MaxPool((2,2), stride = (2,2)))
+
+conv2 = Chain(
+  conv1,
+  Conv((5,5),50), tanh,
+  MaxPool((2,2), stride = (2,2)))
+
+lenet = Chain(
+  conv2,
+  flatten,
+  Dense(500), tanh,
+  Dense(10), softmax)
+
+#--------------------------------------------------------------------------------
+
+# Now we can continue exactly as in plain MXNet, following
+# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
+
+batch_size = 100
+include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
+train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
+
+model = mx.FeedForward(lenet, context = mx.gpu())
+
+optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
+
+mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)
diff --git a/src/Flux.jl b/src/Flux.jl
index c70ecc08..29113fc0 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -15,6 +15,7 @@ include("compiler/loops.jl")
 include("layers/dense.jl")
 include("layers/shape.jl")
 include("layers/chain.jl")
+include("layers/shims.jl")
 
 include("cost.jl")
 include("activation.jl")
diff --git a/src/activation.jl b/src/activation.jl
index d6f7d275..450de921 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -1,4 +1,4 @@
-export σ, relu, softmax
+export σ, relu, softmax, flatten
 
 σ(x) = 1 ./ (1 .+ exp.(-x))
 
@@ -9,3 +9,7 @@ relu(x) = max(0, x)
 back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
 
 softmax(xs) = exp.(xs) ./ sum(exp.(xs))
+
+flatten(xs) = reshape(xs, length(xs))
+
+shape(::typeof(flatten), in) = prod(in)
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index b3728621..b3eb3270 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -25,13 +25,49 @@ function graph(vars, model::Model, args...)
   end |> value
 end
 
+type SoftmaxOutput
+  name::Symbol
+end
+
+function rewrite_softmax(model, name)
+  model == softmax && return SoftmaxOutput(name)
+  g = Flux.graph(model)
+  (g == nothing || value(g) ≠ softmax || Flow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
+  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
+end
+
 # Built-in implemenations
 
 node(::typeof(*), args...) = mx.dot(args...)
 node(::typeof(+), args...) = mx.broadcast_plus(args...)
 node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
 node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
-node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+node(::typeof(tanh), x) = mx.Activation(data = x, act_type=:tanh)
+node(::typeof(flatten), x) = mx.Flatten(data = x)
+
+node(::typeof(softmax), xs) =
+  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+
+node(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
 
 graph(vars, ::Input, x) = x
-graph(vars, ::Input, x, args...) = error("too many arguments to Input")
+
+graph(vars, c::Conv, x) =
+  mx.Convolution(data = x,
+                 kernel = c.size,
+                 num_filter = c.features,
+                 stride = c.stride)
+
+graph(vars, p::MaxPool, x) =
+  mx.Pooling(data = x,
+             pool_type = :max,
+             kernel = p.size,
+             stride = p.stride)
+
+# TODO: fix the initialisation issue
+graph(vars, d::Dense, x) =
+  mx.FullyConnected(data = x,
+                    num_hidden = size(d.W.x, 1),
+                    # weight = graph(vars, d.W),
+                    # bias = graph(vars, d.b)
+                    )
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 60634634..4ff6739e 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -47,9 +47,14 @@ function load!(model::MXModel)
   return model
 end
 
-function mxnet(model::Model, input)
+function mxgraph(model, input)
   vars = Dict{Symbol,Any}()
-  node = graph(vars, model, mx.Variable(:input))
+  node = graph(vars, model, mx.Variable(input))
+  return node, vars
+end
+
+function mxnet(model::Model, input)
+  node, vars = mxgraph(model, :input)
   args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
   grads = mxgrads(args)
   model = MXModel(model, vars, grads,
@@ -81,3 +86,11 @@ function Flux.update!(model::MXModel, η)
   end
   return model
 end
+
+# MX FeedForward interface
+
+function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
+  model = rewrite_softmax(model, label)
+  node, _ = mxgraph(model, input)
+  return mx.FeedForward(node, context = context)
+end
diff --git a/src/layers/chain.jl b/src/layers/chain.jl
index 74f4a757..c6829fcd 100644
--- a/src/layers/chain.jl
+++ b/src/layers/chain.jl
@@ -34,3 +34,5 @@ end
 
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
+
+shape(c::Chain, in) = c.shape
diff --git a/src/layers/shape.jl b/src/layers/shape.jl
index cc7a2718..07fc708a 100644
--- a/src/layers/shape.jl
+++ b/src/layers/shape.jl
@@ -11,6 +11,11 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
 
 # Shim for kicking off shape inference
 
+type ShapeError <: Exception
+  layer
+  shape
+end
+
 type Input{N} <: Model
   dims::Dims{N}
 end
diff --git a/src/layers/shims.jl b/src/layers/shims.jl
new file mode 100644
index 00000000..e6d6917b
--- /dev/null
+++ b/src/layers/shims.jl
@@ -0,0 +1,32 @@
+export Conv, MaxPool
+
+type Conv <: Model
+  size::Dims{2}
+  features::Int
+  stride::Dims{2}
+end
+
+Conv(size, features; stride = (1,1)) =
+  Conv(size, features, stride)
+
+shape(c::Conv, in::Dims{2}) =
+  (map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))..., c.features)
+
+shape(c::Conv, in::Dims{3}) =
+  shape(c, (in[1],in[2]))
+
+type MaxPool <: Model
+  size::Dims{2}
+  stride::Dims{2}
+end
+
+MaxPool(size; stride = (1,1)) =
+  MaxPool(size, stride)
+
+shape(c::MaxPool, in::Dims{2}) =
+  map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
+
+shape(c::MaxPool, in::Dims{3}) =
+  (shape(c, (in[1],in[2]))..., in[3])
+
+shape(c::MaxPool, in) = throw(ShapeError(c, in))
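
A rough sketch of the shape bookkeeping introduced above (not part of the diff): the `shape` methods in src/layers/shims.jl follow the usual `(in - size) ÷ stride + 1` rule, and `flatten`'s output shape is the product of its input dimensions. Assuming `shape` is reachable as `Flux.shape` (it is defined but not exported in this patch), the LeNet example's feature maps would work out as:

# Hypothetical REPL session; the values follow from the formulas in shims.jl
# and activation.jl above.
julia> Flux.shape(Conv((5,5), 20), (28,28))                      # (28-5)÷1+1 = 24 per spatial dim
(24, 24, 20)

julia> Flux.shape(MaxPool((2,2), stride = (2,2)), (24,24,20))    # (24-2)÷2+1 = 12
(12, 12, 20)

julia> Flux.shape(flatten, (4,4,50))                             # prod((4,4,50)) after the second pool,
800                                                              # which is what feeds Dense(500)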