initial torch-like, MNIST demo

This commit is contained in:
Mike J Innes 2016-05-10 17:06:31 +01:00
parent 0b5aad84fc
commit 8314da4207
8 changed files with 182 additions and 3 deletions

15
examples/MNIST.jl Normal file

@@ -0,0 +1,15 @@
using Flux, MNIST

const data = collect(zip([trainfeatures(i) for i = 1:60_000],
                         [onehot(trainlabel(i), 1:10) for i = 1:60_000]))
const train = data[1:50_000]
const test = data[50_001:60_000]

const m = Sequence(
  Input(784),
  Dense(30),
  Sigmoid(),
  Dense(10),
  Sigmoid())

Flux.train!(m, train, test, epoch = 30)
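
After training, a single prediction can be spot-checked by decoding the network's output with onecold. A minimal sketch, assuming the m, test and onecold definitions above:

# Decode the prediction for one held-out example and compare it to the
# true label; both are vectors over the labels 1:10.
x, y = test[1]
onecold(m(x), 1:10) == onecold(y, 1:10)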

src/Flux.jl

@@ -1,10 +1,23 @@
module Flux
using Lazy, Flow
# Zero Flux Given
abstract Capacitor
export Model, back!, update!
macro flux(x)
end
abstract Model
abstract Capacitor <: Model
abstract Activation <: Model
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
update!(m::Model, η) = m
include("utils.jl")
include("cost.jl")
include("activation.jl")
include("layers/input.jl")
include("layers/dense.jl")
include("layers/sequence.jl")
end # module
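
The generic functions above are the whole model interface: a model is called for the forward pass, back! takes the gradient of the cost with respect to the output and accumulates parameter gradients, and update! applies a gradient step. A sketch of a custom layer written against that interface (Scale is a hypothetical example, not part of this commit):

using Flux
import Flux: back!, update!   # extend the generics defined in the module

# A toy one-parameter layer that multiplies its input by a scalar.
type Scale <: Model
  a::Float32
  in::Vector{Float32}
  ∇a::Float32
end

Scale(a) = Scale(a, Float32[], 0)

(l::Scale)(x) = (l.in = x; l.a * x)

function back!(l::Scale, ∇)
  l.∇a += dot(∇, l.in)   # accumulate ∂cost/∂a
  l.a * ∇                # gradient with respect to the input
end

function update!(l::Scale, η)
  l.a -= η * l.∇a
  l.∇a = 0
  return l
end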

26
src/activation.jl Normal file

@@ -0,0 +1,26 @@
export Sigmoid
σ(x) = 1/(1+exp(-x))
σ′(x) = σ(x)*(1-σ(x))

type Sigmoid <: Activation
  in::Vector{Float32}
  out::Vector{Float32}
  ∇in::Vector{Float32}
end

Sigmoid(size::Integer) = Sigmoid(zeros(size), zeros(size), zeros(size))

function (l::Sigmoid)(x)
  l.in = x
  map!(σ, l.out, x)
end

function back!(l::Sigmoid, ∇)
  map!(σ′, l.∇in, l.in)
  map!(*, l.∇in, l.∇in, ∇)
end
shape(l::Sigmoid) = length(l.in)
Sigmoid() = Init(in -> Sigmoid(in[1]))
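
back! above relies on the identity σ′(x) = σ(x)*(1 - σ(x)). A quick finite-difference sketch of that identity, assuming the σ and σ′ definitions in this file:

# σ′ should agree with a centred finite difference of σ.
for x in (-2.0, 0.0, 2.0)
  fd = (σ(x + 1e-6) - σ(x - 1e-6)) / 2e-6
  @assert isapprox(σ′(x), fd, atol = 1e-6)
end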

8
src/cost.jl Normal file

@@ -0,0 +1,8 @@
export mse, mse!
function mse!(∇, pred, target)
  map!(-, ∇, pred, target)
  sumabs2(∇)/2
end

mse(pred, target) = mse!(similar(pred), pred, target)
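
mse! doubles as cost and gradient: it writes pred - target into ∇ (exactly what train! later feeds to back!) and returns half the sum of squared errors. A small worked sketch:

∇ = zeros(2)
cost = mse!(∇, [1.0, 0.0], [0.0, 1.0])
# ∇ is now [1.0, -1.0] and cost == (1 + 1)/2 == 1.0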

41
src/layers/dense.jl Normal file

@@ -0,0 +1,41 @@
export Dense
type Dense <: Model
  W::Matrix{Float32}
  b::Vector{Float32}
  ∇W::Matrix{Float32}
  ∇b::Vector{Float32}
  in::Vector{Float32}
  out::Vector{Float32}
  ∇in::Vector{Float32}
end

Dense(in::Integer, out::Integer) =
  Dense(randn(out, in), randn(out),
        zeros(out, in), zeros(out),
        zeros(in), zeros(out), zeros(in))

Dense(out::Integer) = Init(in -> Dense(in[1], out))

function (l::Dense)(x)
  l.in = x
  A_mul_B!(l.out, l.W, x)
  map!(+, l.out, l.out, l.b)
end

function back!(l::Dense, ∇)
  map!(+, l.∇b, l.∇b, ∇)
  # l.∇W += ∇ * l.in'
  BLAS.gemm!('N', 'T', eltype(∇)(1), ∇, l.in, eltype(∇)(1), l.∇W)
  At_mul_B!(l.∇in, l.W, ∇)
end

function update!(l::Dense, η)
  map!((x, ∇x) -> x - η*∇x, l.W, l.W, l.∇W)
  map!((x, ∇x) -> x - η*∇x, l.b, l.b, l.∇b)
  fill!(l.∇W, 0)
  fill!(l.∇b, 0)
end
shape(d::Dense) = size(d.b)
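
back! accumulates ∂cost/∂W as the outer product ∇ * in' (the gemm! call) and ∂cost/∂b as ∇, while ∇in = W' * ∇ is what gets passed to the layer below. A minimal sketch of one forward/backward/update cycle through a single Dense layer, assuming mse! from src/cost.jl:

l = Dense(3, 2)
x = randn(Float32, 3)
y = Float32[1, 0]

∇ = zeros(Float32, 2)
mse!(∇, l(x), y)   # forward pass; ∇ = out - target
back!(l, ∇)        # l.∇b += ∇,  l.∇W += ∇ * x',  l.∇in = l.W' * ∇
update!(l, 0.1)    # one gradient step, then the accumulators are zeroed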

26
src/layers/input.jl Normal file

@@ -0,0 +1,26 @@
export Input
typealias Dims{N} NTuple{N,Int}
dims(d::Dims) = d
dims(i...) = (i...,)
type Input{N} <: Model
  dims::Dims{N}
end
Input(i) = Input(dims(i))
(::Input)(x) = x
back!(::Input, ∇) = ∇
shape(i::Input) = i.dims
# Initialise placeholder
type Init{F}
  f::F
end
(f::Init)(args...) = f.f(args...)
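
Init is the shape-inference hook: a layer constructor that doesn't yet know its input size returns an Init closure, and Sequence later resolves it against the previous layer's shape. A sketch of the mechanism in isolation, using Dense from src/layers/dense.jl:

d = Dense(30)        # not a layer yet: an Init wrapping in -> Dense(in[1], 30)
layer = d((784,))    # resolved against an input shape of (784,)
size(layer.W)        # (30, 784)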

23
src/layers/sequence.jl Normal file

@@ -0,0 +1,23 @@
export Sequence
type Sequence <: Capacitor
  layers::Vector{Model}
end
Sequence() = Sequence([])
@forward Sequence.layers Base.getindex, Base.first, Base.last
Base.push!(s::Sequence, m::Model) = push!(s.layers, m)
Base.push!(s::Sequence, f::Init) = push!(s, f(shape(last(s))))
function Sequence(ms...)
  s = Sequence()
  foreach(m -> push!(s, m), ms)
  return s
end
(s::Sequence)(x) = foldl((x, m) -> m(x), x, s.layers)
back!(s::Sequence, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
update!(s::Sequence, η) = foreach(l -> update!(l, η), s.layers)
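
The forward pass folds left-to-right over the layers; back! folds right-to-left, threading each layer's input gradient into the preceding layer. A sketch with the layers from this commit:

s = Sequence(Input(2), Dense(3), Sigmoid())
x = randn(Float32, 2)
out = s(x)                         # Input -> Dense -> Sigmoid
back!(s, out - Float32[1, 0, 0])   # Sigmoid -> Dense -> Input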

27
src/utils.jl Normal file

@@ -0,0 +1,27 @@
export onehot, onecold
onehot(label, labels) = [i == label for i in labels]
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
  i = 0
  ∇ = zeros(length(train[1][2]))
  for _ in 1:epoch
    for (x, y) in shuffle!(train)
      i += 1
      err = mse!(∇, m(x), y)
      back!(m, ∇)
      i % batch == 0 && update!(m, η/batch)
    end
    @show accuracy(m, test)
  end
  return m
end
function accuracy(m::Model, data)
  correct = 0
  for (x, y) in data
    onecold(m(x)) == onecold(y) && (correct += 1)
  end
  return correct/length(data)
end
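
accuracy depends on the onehot/onecold round trip: a label is encoded as a 0/1 vector and decoded back as the label of the largest entry. A small sketch:

v = onehot(3, 1:10)        # Bool vector with v[3] == true, all else false
onecold(v, 1:10)           # 3
onecold([0.1, 0.7, 0.2])   # 2, the index of the most active output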