basic convnet example working

parent 205e1215d6
commit 8961b4c10f

examples/mnist-conv.jl (new file, 47 lines)
@@ -0,0 +1,47 @@
+using Flux, MXNet
+
+# Flux aims to provide high-level APIs that work well across backends, but in
+# some cases you may want to take advantage of features specific to a given
+# backend (or alternatively, Flux may simply not have an implementation of that
+# feature yet). In these cases it's easy to "drop down" and use the backend's
+# API directly, where appropriate.
+
+# In this example, both things are happening; firstly, Flux doesn't yet support
+# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
+# a simple "shim" type, `Conv`. This is provided by the library but could easily
+# have been user-defined.
+
+# Secondly, we want to take advantage of MXNet.jl's training process and
+# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
+# regular MXNet model, and the rest of the process is trivial.
+
+conv1 = Chain(
+  Input(28,28),
+  Conv((5,5), 20), tanh,
+  MaxPool((2,2), stride = (2,2)))
+
+conv2 = Chain(
+  conv1,
+  Conv((5,5), 50), tanh,
+  MaxPool((2,2), stride = (2,2)))
+
+lenet = Chain(
+  conv2,
+  flatten,
+  Dense(500), tanh,
+  Dense(10), softmax)
+
+#--------------------------------------------------------------------------------
+
+# Now we can continue exactly as in plain MXNet, following
+# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
+
+batch_size = 100
+include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
+train_provider, eval_provider = get_mnist_providers(batch_size; flat = false)
+
+model = mx.FeedForward(lenet, context = mx.gpu())
+
+optimizer = mx.SGD(lr = 0.05, momentum = 0.9, weight_decay = 0.00001)
+
+mx.fit(model, optimizer, train_provider, n_epoch = 1, eval_data = eval_provider)

@@ -15,6 +15,7 @@ include("compiler/loops.jl")
 include("layers/dense.jl")
 include("layers/shape.jl")
 include("layers/chain.jl")
+include("layers/shims.jl")
 
 include("cost.jl")
 include("activation.jl")

@@ -1,4 +1,4 @@
-export σ, relu, softmax
+export σ, relu, softmax, flatten
 
 σ(x) = 1 ./ (1 .+ exp.(-x))
 
@@ -9,3 +9,7 @@ relu(x) = max(0, x)
 back!(::typeof(relu), Δ, x) = Δ .* (x .< 0)
 
 softmax(xs) = exp.(xs) ./ sum(exp.(xs))
+
+flatten(xs) = reshape(xs, length(xs))
+
+shape(::typeof(flatten), in) = prod(in)
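
A quick REPL sketch of what the new `flatten` does (illustrative, not part of the commit):

  julia> xs = rand(4, 4, 50);

  julia> size(flatten(xs))   # reshape(xs, length(xs)) gives a flat vector
  (800,)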

@@ -25,13 +25,49 @@ function graph(vars, model::Model, args...)
 end |> value
 end
 
+type SoftmaxOutput
+  name::Symbol
+end
+
+function rewrite_softmax(model, name)
+  model == softmax && return SoftmaxOutput(name)
+  g = Flux.graph(model)
+  (g == nothing || value(g) ≠ softmax || Flow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
+  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
+end
+
 # Built-in implementations
 
 node(::typeof(*), args...) = mx.dot(args...)
 node(::typeof(+), args...) = mx.broadcast_plus(args...)
 node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
 node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
-node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+node(::typeof(tanh), x) = mx.Activation(data = x, act_type=:tanh)
+
+node(::typeof(flatten), x) = mx.Flatten(data = x)
+
+node(::typeof(softmax), xs) =
+  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+
+node(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
 
 graph(vars, ::Input, x) = x
+graph(vars, ::Input, x, args...) = error("too many arguments to Input")
+
+graph(vars, c::Conv, x) =
+  mx.Convolution(data = x,
+                 kernel = c.size,
+                 num_filter = c.features,
+                 stride = c.stride)
+
+graph(vars, p::MaxPool, x) =
+  mx.Pooling(data = x,
+             pool_type = :max,
+             kernel = p.size,
+             stride = p.stride)
+
+# TODO: fix the initialisation issue
+graph(vars, d::Dense, x) =
+  mx.FullyConnected(data = x,
+                    num_hidden = size(d.W.x, 1),
+                    # weight = graph(vars, d.W),
+                    # bias = graph(vars, d.b)
+                    )
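
A note on the `SoftmaxOutput` rewrite above: MXNet's FeedForward training expects the symbolic graph to end in an output node such as mx.SoftmaxOutput, which binds the label and provides the loss, rather than a plain softmax op. `rewrite_softmax` therefore swaps a trailing `softmax` for the `SoftmaxOutput` shim. For example (illustrative calls, traced by hand from the code above):

  rewrite_softmax(softmax, :softmax)                     # returns SoftmaxOutput(:softmax)
  rewrite_softmax(Chain(Dense(10), softmax), :softmax)   # wraps the graph's tail in SoftmaxOutput
  rewrite_softmax(Dense(10), :softmax)                   # error: model must end with `softmax`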

@@ -47,9 +47,14 @@ function load!(model::MXModel)
   return model
 end
 
-function mxnet(model::Model, input)
+function mxgraph(model, input)
   vars = Dict{Symbol,Any}()
-  node = graph(vars, model, mx.Variable(:input))
+  node = graph(vars, model, mx.Variable(input))
+  return node, vars
+end
+
+function mxnet(model::Model, input)
+  node, vars = mxgraph(model, :input)
   args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
   grads = mxgrads(args)
   model = MXModel(model, vars, grads,

@@ -81,3 +86,11 @@ function Flux.update!(model::MXModel, η)
   end
   return model
 end
+
+# MX FeedForward interface
+
+function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
+  model = rewrite_softmax(model, label)
+  node, _ = mxgraph(model, input)
+  return mx.FeedForward(node, context = context)
+end
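
Putting the pieces together: `mx.FeedForward(lenet, context = mx.gpu())` in the example rewrites the trailing `softmax` into a `SoftmaxOutput(:softmax)` node, lowers the Flux model to an MXNet symbol via `mxgraph`, and hands that symbol to MXNet's stock `mx.FeedForward`, so `mx.fit` then works exactly as it would on a hand-written MXNet network. A minimal CPU variant of the example's training setup (same calls, illustrative only):

  model     = mx.FeedForward(lenet)   # defaults: input = :data, label = :softmax, context = mx.cpu()
  optimizer = mx.SGD(lr = 0.05, momentum = 0.9, weight_decay = 0.00001)
  mx.fit(model, optimizer, train_provider, n_epoch = 1, eval_data = eval_provider)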

@@ -34,3 +34,5 @@ end
 
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
+
+shape(c::Chain, in) = c.shape

@@ -11,6 +11,11 @@ single(i::Dims) = length(i) == 1 ? first(i) : i
 
 # Shim for kicking off shape inference
 
+type ShapeError <: Exception
+  layer
+  shape
+end
+
 type Input{N} <: Model
   dims::Dims{N}
 end

src/layers/shims.jl (new file, 32 lines)
@@ -0,0 +1,32 @@
+export Conv, MaxPool
+
+type Conv <: Model
+  size::Dims{2}
+  features::Int
+  stride::Dims{2}
+end
+
+Conv(size, features; stride = (1,1)) =
+  Conv(size, features, stride)
+
+shape(c::Conv, in::Dims{2}) =
+  (map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))..., c.features)
+
+shape(c::Conv, in::Dims{3}) =
+  shape(c, (in[1],in[2]))
+
+type MaxPool <: Model
+  size::Dims{2}
+  stride::Dims{2}
+end
+
+MaxPool(size; stride = (1,1)) =
+  MaxPool(size, stride)
+
+shape(c::MaxPool, in::Dims{2}) =
+  map(i -> (in[i]-c.size[i])÷c.stride[i]+1, (1,2))
+
+shape(c::MaxPool, in::Dims{3}) =
+  (shape(c, (in[1],in[2]))..., in[3])
+
+shape(c::MaxPool, in) = throw(ShapeError(c, in))
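
These `shape` rules drive the walk-through shown under the example above: Conv and MaxPool map each spatial dimension through (in - size) ÷ stride + 1, Conv replaces the channel count with `features`, and MaxPool preserves it. The catch-all MaxPool method turns an unsupported input shape into the new typed `ShapeError` rather than a generic failure. For instance (hand-checked against the formulas, not program output):

  shape(Conv((5,5), 20), (28,28))                     # (24, 24, 20)
  shape(MaxPool((2,2), stride = (2,2)), (24,24,20))   # (12, 12, 20)
  shape(MaxPool((2,2)), (28,))                        # throws ShapeError(layer, shape)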