basic convnet example working
This commit is contained in:
parent
205e1215d6
commit
8961b4c10f
|
@ -0,0 +1,47 @@
|
|||
using Flux, MXNet
|
||||
|
||||
# Flux aims to provide high-level APIs that work well across backends, but in
|
||||
# some cases you may want to take advantage of features specific to a given
|
||||
# backend (or alternatively, Flux may simply not have an implementation of that
|
||||
# feature yet). In these cases it's easy to "drop down" and use the backend's
|
||||
# API directly, where appropriate.
|
||||
|
||||
# In this example, both things are happening; firstly, Flux doesn't yet support
|
||||
# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
|
||||
# a simple "shim" type, `Conv`. This is provided by the library but could easily
|
||||
# have been user-defined.
|
||||
|
||||
# Secondly, we want to take advantage of MXNet.jl's training process and
|
||||
# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
|
||||
# regular MXNet model, and the rest of the process is trivial.
|
||||
|
||||
conv1 = Chain(
|
||||
Input(28,28),
|
||||
Conv((5,5),20), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
conv2 = Chain(
|
||||
conv1,
|
||||
Conv((5,5),50), tanh,
|
||||
MaxPool((2,2), stride = (2,2)))
|
||||
|
||||
lenet = Chain(
|
||||
conv2,
|
||||
flatten,
|
||||
Dense(500), tanh,
|
||||
Dense(10), softmax)
|
||||
|
||||
#--------------------------------------------------------------------------------
|
||||
|
||||
# Now we can continue exactly as in plain MXNet, following
|
||||
# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
|
||||
|
||||
batch_size = 100
|
||||
include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
|
||||
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
|
||||
|
||||
model = mx.FeedForward(lenet, context = mx.gpu())
|
||||
|
||||
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
|
||||
|
||||
mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)
|
|
@ -15,6 +15,7 @@ include("compiler/loops.jl")
|
|||
include("layers/dense.jl")
|
||||
include("layers/shape.jl")
|
||||
include("layers/chain.jl")
|
||||
include("layers/shims.jl")
|
||||
|
||||
include("cost.jl")
|
||||
include("activation.jl")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
export σ, relu, softmax
|
||||
export σ, relu, softmax, flatten
|
||||
|
||||
σ(x) = 1 ./ (1 .+ exp.(-x))
|
||||
|
||||
|
@ -9,3 +9,7 @@ relu(x) = max(0, x)
|
|||
back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
|
||||
|
||||
softmax(xs) = exp.(xs) ./ sum(exp.(xs))
|
||||
|
||||
flatten(xs) = reshape(xs, length(xs))
|
||||
|
||||
shape(::typeof(flatten), in) = prod(in)
|
||||
|
|
|
@ -25,13 +25,49 @@ function graph(vars, model::Model, args...)
|
|||
end |> value
|
||||
end
|
||||
|
||||
type SoftmaxOutput
|
||||
name::Symbol
|
||||
end
|
||||
|
||||
function rewrite_softmax(model, name)
|
||||
model == softmax && return SoftmaxOutput(name)
|
||||
g = Flux.graph(model)
|
||||
(g == nothing || value(g) ≠ softmax || Flow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
|
||||
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
||||
end
|
||||
|
||||
# Built-in implemenations
|
||||
|
||||
node(::typeof(*), args...) = mx.dot(args...)
|
||||
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
|
||||
node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||
node(::typeof(tanh), x) = mx.Activation(data = x, act_type=:tanh)
|
||||
node(::typeof(flatten), x) = mx.Flatten(data = x)
|
||||
|
||||
node(::typeof(softmax), xs) =
|
||||
mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||
|
||||
node(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
|
||||
|
||||
graph(vars, ::Input, x) = x
|
||||
graph(vars, ::Input, x, args...) = error("too many arguments to Input")
|
||||
|
||||
graph(vars, c::Conv, x) =
|
||||
mx.Convolution(data = x,
|
||||
kernel = c.size,
|
||||
num_filter = c.features,
|
||||
stride = c.stride)
|
||||
|
||||
graph(vars, p::MaxPool, x) =
|
||||
mx.Pooling(data = x,
|
||||
pool_type = :max,
|
||||
kernel = p.size,
|
||||
stride = p.stride)
|
||||
|
||||
# TODO: fix the initialisation issue
|
||||
graph(vars, d::Dense, x) =
|
||||
mx.FullyConnected(data = x,
|
||||
num_hidden = size(d.W.x, 1),
|
||||
# weight = graph(vars, d.W),
|
||||
# bias = graph(vars, d.b)
|
||||
)
|
||||
|
|
|
@ -47,9 +47,14 @@ function load!(model::MXModel)
|
|||
return model
|
||||
end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
function mxgraph(model, input)
|
||||
vars = Dict{Symbol,Any}()
|
||||
node = graph(vars, model, mx.Variable(:input))
|
||||
node = graph(vars, model, mx.Variable(input))
|
||||
return node, vars
|
||||
end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
node, vars = mxgraph(model, :input)
|
||||
args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
|
||||
grads = mxgrads(args)
|
||||
model = MXModel(model, vars, grads,
|
||||
|
@ -81,3 +86,11 @@ function Flux.update!(model::MXModel, η)
|
|||
end
|
||||
return model
|
||||
end
|
||||
|
||||
# MX FeedForward interface
|
||||
|
||||
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
node, _ = mxgraph(model, input)
|
||||
return mx.FeedForward(node, context = context)
|
||||
end
|
||||
|
|
|
@ -34,3 +34,5 @@ end
|
|||
|
||||
graph(s::Chain) =
|
||||
foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
|
||||
|
||||
shape(c::Chain, in) = c.shape
|
||||
|
|
|
@ -11,6 +11,11 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
|
|||
|
||||
# Shim for kicking off shape inference
|
||||
|
||||
type ShapeError <: Exception
|
||||
layer
|
||||
shape
|
||||
end
|
||||
|
||||
type Input{N} <: Model
|
||||
dims::Dims{N}
|
||||
end
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
export Conv, MaxPool
|
||||
|
||||
type Conv <: Model
|
||||
size::Dims{2}
|
||||
features::Int
|
||||
stride::Dims{2}
|
||||
end
|
||||
|
||||
Conv(size, features; stride = (1,1)) =
|
||||
Conv(size, features, stride)
|
||||
|
||||
shape(c::Conv, in::Dims{2}) =
|
||||
(map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))..., c.features)
|
||||
|
||||
shape(c::Conv, in::Dims{3}) =
|
||||
shape(c, (in[1],in[2]))
|
||||
|
||||
type MaxPool <: Model
|
||||
size::Dims{2}
|
||||
stride::Dims{2}
|
||||
end
|
||||
|
||||
MaxPool(size; stride = (1,1)) =
|
||||
MaxPool(size, stride)
|
||||
|
||||
shape(c::MaxPool, in::Dims{2}) =
|
||||
map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
|
||||
|
||||
shape(c::MaxPool, in::Dims{3}) =
|
||||
(shape(c, (in[1],in[2]))..., in[3])
|
||||
|
||||
shape(c::MaxPool, in) = throw(ShapeError(c, in))
|
Loading…
Reference in New Issue