a bunch of stuff

Mike J Innes 2016-08-25 22:49:21 +01:00
parent 4249e6e961
commit c92cff5dce
12 changed files with 114 additions and 66 deletions

examples/MNIST.jl

@@ -1,6 +1,4 @@
-using Flux, MNIST, Flow, MacroTools
-import Flux.MX: mxnet
-import Flux: back!, update!, graph
+using Flux, MNIST

 @time begin
 const data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
@@ -11,9 +9,9 @@ end

 m = Chain(
   Input(784),
-  Dense(784, 128), relu,
-  Dense(128, 64), relu,
-  Dense(64, 10), softmax)
+  Dense(128), relu,
+  Dense( 64), relu,
+  Dense( 10), softmax)

 model = mxnet(m, 784)
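With the shape-inference shim this commit adds (src/layers/shape.jl and src/layers/chain.jl below), only the Input size is spelled out; each Dense(out) is an Init placeholder that Chain resolves against the previous layer's output shape. A minimal sketch of the equivalence, using the sizes from the lines removed above:

# What the Chain above resolves to at construction time:
m_explicit = Chain(
  Input(784),
  Dense(784, 128), relu,
  Dense(128, 64), relu,
  Dense(64, 10), softmax)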

src/Flux.jl

@@ -10,12 +10,13 @@ include("utils.jl")
 include("compiler/diff.jl")
 include("compiler/code.jl")

+include("layers/dense.jl")
+include("layers/shape.jl")
+include("layers/chain.jl")
+
 include("cost.jl")
 include("activation.jl")
-include("layers/input.jl")
-include("layers/dense.jl")
-include("layers/sequence.jl")

-include("backend/mxnet/mxnet.jl")
+include("backend/backend.jl")

 end # module

src/activation.jl

@@ -8,4 +8,4 @@ relu(x) = max(0, x)
 back(::typeof(relu), Δ, x) = Δ .* (x .< 0)

-softmax(x) = error("not implemented")
+softmax(xs) = exp.(xs) ./ sum(exp.(xs))
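Sanity check for the new definition: softmax([1.0, 2.0, 3.0]) ≈ [0.0900, 0.2447, 0.6652], which sums to one. Two hedged asides: exponentiating directly overflows for large inputs (exp(1000) == Inf), so a stabler variant (not in this commit) would shift by the maximum first; and the relu backward above masks with x .< 0, which looks inverted, since relu passes gradient where x > 0.

# Numerically stabler variant, a sketch only:
softmax(xs) = (e = exp.(xs .- maximum(xs)); e ./ sum(e))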

src/backend/backend.jl Normal file (+7)

@@ -0,0 +1,7 @@
+# TODO: load backends lazily
+
+include("mxnet/mxnet.jl")
+
+using .MX
+export mxnet

src/backend/mxnet/graph.jl

@@ -32,3 +32,6 @@ node(::typeof(+), args...) = mx.broadcast_plus(args...)
 node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
 node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
+node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+
+graph(vars, ::Input, x) = x
+graph(vars, ::Input, x, args...) = error("too many arguments to Input")
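A reading aid for the new softmax node, sketched under the assumption that mx.sum with no axis argument reduces over every axis of the symbol:

# exp(xs)                        -- elementwise, same shape as xs
# mx.sum(exp(xs))                -- total over all entries
# mx.Reshape(..., shape = (1,1)) -- make the scalar broadcastable
# mx.broadcast_div(...)          -- divide every entry by the total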

src/backend/mxnet/mxnet.jl

@@ -2,6 +2,8 @@ module MX

 using MXNet, Flow, ..Flux

 export mxnet

 include("graph.jl")
+include("model.jl")

src/layers/chain.jl Normal file (+30)

@@ -0,0 +1,30 @@
+export Chain
+
+function inferchain(ms)
+  chain = []
+  sh = nothing
+  for m in ms
+    m = init(m, single(sh))
+    sh = shape(m, sh)
+    push!(chain, m)
+  end
+  return chain, sh
+end
+
+type Chain <: Model
+  layers::Vector{Any}
+  shape
+  function Chain(ms...)
+    ms, shape = inferchain(ms)
+    return new(ms, shape)
+  end
+end
+
+@forward Chain.layers Base.getindex, Base.first, Base.last
+
+(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
+back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
+update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
+
+graph(s::Chain) =
+  foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
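To make the construction concrete, a hedged trace of inferchain over the MNIST layers, assuming the init, shape and single definitions from src/layers/shape.jl below:

# inferchain((Input(784), Dense(128), relu)):
#   sh = nothing
#   Input(784): init passes it through; sh = shape(Input(784), nothing) = (784,)
#   Dense(128): an Init placeholder, so init calls it with single((784,)) = 784,
#               yielding Dense(784, 128); sh becomes 128
#   relu:       init passes it through; the shape(x, in) = in fallback keeps sh = 128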

src/layers/dense.jl

@@ -1,5 +1,7 @@
 export Dense

+# TODO: type hints for parameters
+
 @model type Dense
   W
   b
@@ -9,7 +11,8 @@ end
 Dense(in::Integer, out::Integer; init = initn) =
   Dense(init(out, in), init(out))

-Base.show(io::IO, ::Dense) = print(io, "Flux.Dense(...)")
+Base.show(io::IO, d::Dense) =
+  print(io, "Flux.Dense($(size(d.W.x,2)),$(size(d.W.x,1)))")

 @model type Sigmoid
   layer::Model
@@ -18,3 +21,13 @@ end
 Sigmoid(in::Integer, out::Integer; init = randn) =
   Sigmoid(Dense(in, out, init = init))
+
+# @model type Recurrent
+#   Wxh; Whh; Bh
+#   Wxy; Why; By
+#
+#   function (x)
+#     hidden = σ( Wxh*x + Whh*hidden + Bh )
+#     y = σ( Wxy*x + Why*hidden + By )
+#   end
+# end
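The revised show method reads a layer's sizes off its weight matrix; a small sketch, assuming (as the d.W.x access above implies) that @model wraps parameters in a holder whose raw array lives in the x field:

d = Dense(784, 128)   # initn(out, in): W is 128×784, b has length 128
show(STDOUT, d)       # prints Flux.Dense(784,128); size(W,2) is in, size(W,1) is out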

src/layers/input.jl

@@ -1,26 +0,0 @@
-export Input
-
-typealias Dims{N} NTuple{N,Int}
-
-dims(d::Dims) = d
-dims(i...) = (i...,)
-
-type Input{N} <: Model
-  dims::Dims{N}
-end
-
-Input(i...) = Input(dims(i...))
-
-(::Input)(x) = x
-back!(::Input, Δ, x) = Δ
-shape(i::Input) = i.dims
-
-# Initialise placeholder
-type Init{F}
-  f::F
-end
-
-(f::Init)(args...) = f.f(args...)

src/layers/sequence.jl

@@ -1,23 +0,0 @@
-export Sequence
-
-type Sequence <: Model
-  layers::Vector{Model}
-end
-
-Sequence() = Sequence([])
-
-@forward Sequence.layers Base.getindex, Base.first, Base.last
-
-Base.push!(s::Sequence, m::Model) = push!(s.layers, m)
-Base.push!(s::Sequence, f::Init) = push!(s, f(shape(last(s))))
-
-function Sequence(ms...)
-  s = Sequence()
-  foreach(m -> push!(s, m), ms)
-  return s
-end
-
-(s::Sequence)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Sequence, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
-update!(s::Sequence, η) = foreach(l -> update!(l, η), s.layers)

src/layers/shape.jl Normal file (+42)

@@ -0,0 +1,42 @@
+export Input
+
+# Shim for kicking off shape inference
+
+typealias Dims{N} NTuple{N,Int}
+
+dims(d::Dims) = d
+dims(i...) = (i...,)
+
+single(i) = i
+single(i::Dims) = length(i) == 1 ? first(i) : i
+
+type Input{N} <: Model
+  dims::Dims{N}
+end
+
+Input(i...) = Input(dims(i...))
+
+(::Input)(x) = x
+back!(::Input, Δ, x) = Δ
+
+# Initialise placeholder
+
+type Init{F}
+  f::F
+end
+
+init(i::Init, input...) = i.f(input...)
+init(m, input...) = m
+
+# Shape inference API
+
+shape(x, in) = in
+shape(i::Input, _) = i.dims
+
+# Implementation for bundled layers
+
+shape(d::Dense, _) = length(state(d.b)) # TODO: could perhaps infer this
+
+Dense(out::Integer) = Init(in::Integer -> Dense(in, out))
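A short REPL-style sketch of the helpers above; the commented results are what the definitions imply:

dims(28, 28)        # (28, 28)
single((784,))      # 784: one-element shape tuples unwrap to an Int
single((28, 28))    # (28, 28)
init(Dense(10), 64) # Dense(10) is an Init, so this resolves to Dense(64, 10)
init(relu, 64)      # relu is not an Init; the fallback returns it unchanged
shape(relu, 128)    # 128: the default shape is identity on the incoming shape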

src/utils.jl

@@ -5,19 +5,20 @@ const AArray = AbstractArray
 onehot(label, labels) = [i == label for i in labels]
 onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]

-function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
+initn(dims...) = randn(dims...)/100
+
+function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
   i = 0
   Δ = zeros(length(train[1][2]))
   for _ in 1:epoch
-    for (x, y) in shuffle!(train)
+    for (x, y) in train
       i += 1
-      err = mse!(Δ, m(x), y)
+      pred = m(x)
+      any(isnan, pred) && error("NaN")
+      err = mse!(Δ, pred, y)
       back!(m, Δ, x)
-      i % batch == 0 && update!(m, η/batch)
+      i % batch == 0 && (update!(m, η); @show accuracy(m, test))
     end
-    @show accuracy(m, test)
   end
   return m
 end
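Putting the pieces together, a hedged end-to-end sketch; data, m and model are as in the MNIST example at the top, and the 50,000/10,000 split is illustrative rather than from the commit:

train_set = data[1:50_000]
test_set  = data[50_001:60_000]
train!(model, train_set, test_set, epoch = 1, batch = 10, η = 0.1)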