initial torch-like, MNIST demo
This commit is contained in:
parent
0b5aad84fc
commit
8314da4207
15
examples/MNIST.jl
Normal file
15
examples/MNIST.jl
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
using Flux, MNIST

# Pair every training image with a one-hot encoding of its label.
# MNIST labels are the digits 0-9, so the label set must be 0:9; with 1:10 the
# digit 0 would get an all-false target and every other digit would be shifted
# by one class.
const data = collect(zip([trainfeatures(i) for i = 1:60_000],
                         [onehot(trainlabel(i), 0:9) for i = 1:60_000]))

# First 50,000 samples are used for training, the remaining 10,000 for testing.
const train = data[1:50_000]
const test = data[50_001:60_000]

# 784 inputs -> 30 hidden sigmoid units -> 10 sigmoid outputs.
const m = Sequence(
  Input(784),
  Dense(30),
  Sigmoid(),
  Dense(10),
  Sigmoid())

# `train!` takes the training set as a required positional argument; the test
# set is passed as well so that accuracy is reported after each epoch.
Flux.train!(m, train, test, epoch = 30)
|
19
src/Flux.jl
19
src/Flux.jl
@ -1,10 +1,23 @@
|
|||||||
module Flux
|
module Flux
|
||||||
|
|
||||||
|
using Lazy, Flow
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
abstract Capacitor
|
export Model, back!, update!
|
||||||
|
|
||||||
macro flux(x)
|
abstract Model
|
||||||
end
|
abstract Capacitor <: Model
|
||||||
|
abstract Activation <: Model
|
||||||
|
|
||||||
|
# Fallback backward pass: a model type that does not define its own `back!`
# raises an error naming the offending type.
function back!(m::Model, ∇)
  error("Backprop not implemented for $(typeof(m))")
end

# Fallback parameter update: does nothing and hands the model back unchanged.
function update!(m::Model, η)
  return m
end
|
||||||
|
|
||||||
|
include("utils.jl")
|
||||||
|
include("cost.jl")
|
||||||
|
include("activation.jl")
|
||||||
|
include("layers/input.jl")
|
||||||
|
include("layers/dense.jl")
|
||||||
|
include("layers/sequence.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
26
src/activation.jl
Normal file
26
src/activation.jl
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
export Sigmoid
|
||||||
|
|
||||||
|
# Logistic sigmoid of a scalar `x`.
σ(x) = 1/(1+exp(-x))

# Derivative of the logistic sigmoid, σ(x)*(1-σ(x)).
# σ(x) is evaluated once and reused, so only one `exp` is computed per call;
# the result is identical to evaluating σ(x) twice.
function σ′(x)
  s = σ(x)
  return s*(1-s)
end
|
||||||
|
|
||||||
|
# Element-wise logistic activation layer with preallocated Float32 buffers.
type Sigmoid <: Activation
  in::Vector{Float32}   # input most recently passed to the layer
  out::Vector{Float32}  # buffer receiving σ applied to the input
  ∇in::Vector{Float32}  # buffer receiving the gradient w.r.t. the input
end

# Build a layer whose three buffers all have `size` zeroed elements.
function Sigmoid(size::Integer)
  blank() = zeros(Float32, size)
  return Sigmoid(blank(), blank(), blank())
end
|
||||||
|
|
||||||
|
# Forward pass: remember the input (needed later by `back!`), write σ of every
# element of `x` into the layer's output buffer, and return that buffer.
function (l::Sigmoid)(x)
  l.in = x
  activated = map!(σ, l.out, x)
  return activated
end
|
||||||
|
|
||||||
|
# Backward pass: given `∇`, the gradient w.r.t. the layer's output, fill
# `l.∇in` with the gradient w.r.t. its input, σ′(in) .* ∇, and return it.
#
# Both factors are combined in a single pass over the buffers instead of first
# storing σ′(in) and then multiplying in place. `l.in` holds Float32 values, so
# σ′ already yields Float32 and the fused product rounds exactly as before.
function back!(l::Sigmoid, ∇)
  return map!((x, δ) -> σ′(x)*δ, l.∇in, l.in, ∇)
end
|
||||||
|
|
||||||
|
# Output size of the layer: the number of elements in its input buffer
# (the activation is element-wise, so input and output lengths agree).
function shape(l::Sigmoid)
  return length(l.in)
end

# Size-less constructor: returns an `Init` placeholder that builds the actual
# layer once the shape of the preceding layer is known, using its first entry.
function Sigmoid()
  return Init(prev -> Sigmoid(prev[1]))
end
|
8
src/cost.jl
Normal file
8
src/cost.jl
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
export mse, mse!
|
||||||
|
|
||||||
|
# Squared-error cost, computed in place.
#
# Writes the residual `pred - target` into `∇` (which is also the gradient of
# the cost with respect to `pred`) and returns half the sum of its squares.
function mse!(∇, pred, target)
  map!(-, ∇, pred, target)
  # sum(abs2, ·) is the same reduction as sumabs2(·).
  return sum(abs2, ∇)/2
end

# Non-mutating variant: allocates a scratch buffer shaped like `pred` and
# forwards to `mse!`. (The original called a three-argument `mse`, which does
# not exist, so every call raised a MethodError.)
mse(pred, target) = mse!(similar(pred), pred, target)
|
41
src/layers/dense.jl
Normal file
41
src/layers/dense.jl
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
export Dense
|
||||||
|
|
||||||
|
# Fully connected (affine) layer with preallocated Float32 buffers.
type Dense <: Model
  W::Matrix{Float32}    # weights, out × in
  b::Vector{Float32}    # biases, one per output
  ∇W::Matrix{Float32}   # accumulated gradient w.r.t. W
  ∇b::Vector{Float32}   # accumulated gradient w.r.t. b

  in::Vector{Float32}   # input most recently passed to the layer
  out::Vector{Float32}  # buffer receiving the layer's output
  ∇in::Vector{Float32}  # buffer receiving the gradient w.r.t. the input
end

# Build a layer mapping `in` inputs to `out` outputs. Weights and biases are
# drawn from a standard normal distribution; every other buffer starts at zero.
function Dense(in::Integer, out::Integer)
  W = randn(out, in)
  b = randn(out)
  return Dense(W, b,
               zeros(Float32, out, in), zeros(Float32, out),
               zeros(Float32, in), zeros(Float32, out), zeros(Float32, in))
end
|
||||||
|
|
||||||
|
# Output-size-only constructor: returns an `Init` placeholder that builds the
# layer once the shape of the preceding layer is known, taking its first entry
# as the input size.
function Dense(out::Integer)
  return Init(prev -> Dense(prev[1], out))
end
|
||||||
|
|
||||||
|
# Forward pass: remember the input (needed later by `back!`), compute
# W*x + b into the layer's output buffer, and return that buffer.
function (l::Dense)(x)
  l.in = x
  A_mul_B!(l.out, l.W, x)                     # l.out = W * x
  return broadcast!(+, l.out, l.out, l.b)     # l.out += b, element-wise
end
|
||||||
|
|
||||||
|
# Backward pass: given `∇`, the gradient w.r.t. the layer's output, accumulate
# the parameter gradients and return the gradient w.r.t. the layer's input.
function back!(l::Dense, ∇)
  unit = one(eltype(∇))
  # ∇b += ∇
  map!(+, l.∇b, l.∇b, ∇)
  # ∇W += ∇ * in'  (rank-one update performed in place by BLAS)
  BLAS.gemm!('N', 'T', unit, ∇, l.in, unit, l.∇W)
  # ∇in = W' * ∇
  return At_mul_B!(l.∇in, l.W, ∇)
end
|
||||||
|
|
||||||
|
# Gradient-descent step: move W and b against their accumulated gradients,
# scaled by the learning rate `η`, then reset both accumulators to zero.
function update!(l::Dense, η)
  descend(θ, ∇θ) = θ - η*∇θ
  for (θ, ∇θ) in ((l.W, l.∇W), (l.b, l.∇b))
    map!(descend, θ, θ, ∇θ)
  end
  fill!(l.∇W, 0)
  return fill!(l.∇b, 0)
end
|
||||||
|
|
||||||
|
# Output shape of the layer, reported as the size tuple of its bias vector.
shape(d::Dense) = size(d.b)
|
26
src/layers/input.jl
Normal file
26
src/layers/input.jl
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
export Input
|
||||||
|
|
||||||
|
typealias Dims{N} NTuple{N,Int}
|
||||||
|
|
||||||
|
# Normalise a size specification to a tuple.
# A `Dims` tuple is passed through untouched.
function dims(d::Dims)
  return d
end

# Any other argument list (e.g. one or more integers) is packed into a tuple.
function dims(i...)
  return tuple(i...)
end
|
||||||
|
|
||||||
|
# Marker layer recording the shape of the data fed to a model.
type Input{N} <: Model
  dims::Dims{N}
end

# Convenience constructor: the argument is first normalised with `dims`, so a
# single integer such as `Input(784)` is accepted as well as a tuple.
function Input(i)
  return Input(dims(i))
end
|
||||||
|
|
||||||
|
# Forward pass: the input layer hands its argument through unchanged.
function (layer::Input)(x)
  return x
end

# Backward pass: the incoming gradient is handed through unchanged.
function back!(layer::Input, ∇)
  return ∇
end

# Shape of the layer: the dimensions it was constructed with.
function shape(i::Input)
  return i.dims
end
|
||||||
|
|
||||||
|
# Initialise placeholder
|
||||||
|
|
||||||
|
# Deferred-construction wrapper around a function `f`.
type Init{F}
  f::F
end

# Calling the wrapper calls the wrapped function with the same arguments.
function (init::Init)(args...)
  return init.f(args...)
end
|
23
src/layers/sequence.jl
Normal file
23
src/layers/sequence.jl
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
export Sequence
|
||||||
|
|
||||||
|
# A model made of layers stored in order.
type Sequence <: Capacitor
  layers::Vector{Model}
end

# Empty sequence: starts with no layers.
function Sequence()
  return Sequence(Model[])
end
|
||||||
|
|
||||||
|
# Delegate indexing, `first` and `last` on a Sequence to its `layers` vector.
# NOTE(review): `@forward` is presumably the macro from Lazy (Flux.jl has `using Lazy`) — confirm.
@forward Sequence.layers Base.getindex, Base.first, Base.last
|
||||||
|
|
||||||
|
# Append an already-constructed layer; returns the underlying layer vector.
function Base.push!(s::Sequence, m::Model)
  return push!(s.layers, m)
end

# Append a deferred layer: the placeholder is called with the shape of the
# current last layer, and whatever it returns is pushed in its place.
function Base.push!(s::Sequence, f::Init)
  preceding = shape(last(s))
  return push!(s, f(preceding))
end
|
||||||
|
|
||||||
|
# Build a sequence from the given layers (or `Init` placeholders), pushing
# them one at a time in the order supplied.
function Sequence(ms...)
  s = Sequence()
  for m in ms
    push!(s, m)
  end
  return s
end
|
||||||
|
|
||||||
|
# Forward pass: feed `x` through every layer, first to last, and return the
# final result.
function (s::Sequence)(x)
  for layer in s.layers
    x = layer(x)
  end
  return x
end

# Backward pass: propagate `∇` through the layers from last to first and
# return the gradient produced by the first layer.
function back!(s::Sequence, ∇)
  for k in length(s.layers):-1:1
    ∇ = back!(s.layers[k], ∇)
  end
  return ∇
end

# Apply `update!` with learning rate `η` to every layer in order.
function update!(s::Sequence, η)
  for layer in s.layers
    update!(layer, η)
  end
  return nothing
end
|
27
src/utils.jl
Normal file
27
src/utils.jl
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
export onehot, onecold
|
||||||
|
|
||||||
|
# One-hot encode `label` against `labels`: the result has one entry per
# element of `labels`, true exactly where that element equals `label`.
function onehot(label, labels)
  return map(candidate -> candidate == label, labels)
end
|
||||||
|
# Inverse of `onehot`: return the label at the position of the largest entry
# of `pred` (the first such position when there are ties). `labels` defaults
# to the positions 1:length(pred).
#
# The position is located with the predicate form of `findfirst`, which has
# the same meaning as the value-search form `findfirst(pred, maximum(pred))`
# but is also accepted by Julia releases that removed the latter.
function onecold(pred, labels = 1:length(pred))
  peak = maximum(pred)
  return labels[findfirst(p -> p == peak, pred)]
end
|
||||||
|
|
||||||
|
# Train `m` with mini-batch gradient descent on a squared-error cost.
#
# Arguments:
#   m      model to train; it is updated in place and returned
#   train  indexable collection of (input, target) pairs; NOTE: it is shuffled
#          in place at the start of every epoch
#   test   optional collection of (input, target) pairs; when non-empty, the
#          accuracy on it is printed after each epoch
# Keywords:
#   epoch  number of passes over `train`
#   batch  number of samples whose gradients are accumulated per update
#   η      learning rate (divided by `batch` when applied)
function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
  # Nothing to learn from; also avoids indexing train[1] below.
  isempty(train) && return m
  i = 0
  # Scratch buffer for the cost gradient, sized like a target vector.
  ∇ = zeros(length(train[1][2]))
  for _ in 1:epoch
    for (x, y) in shuffle!(train)
      i += 1
      mse!(∇, m(x), y)   # fills ∇; the cost value itself is not needed here
      back!(m, ∇)
      i % batch == 0 && update!(m, η/batch)
    end
    # With no test data the accuracy would be 0/0 = NaN, so skip the report.
    if !isempty(test)
      @show accuracy(m, test)
    end
  end
  return m
end
|
||||||
|
|
||||||
|
# Fraction of (input, target) pairs in `data` for which the model's strongest
# output position matches the strongest position of the target.
function accuracy(m::Model, data)
  hits = count(sample -> onecold(m(sample[1])) == onecold(sample[2]), data)
  return hits/length(data)
end
|
Loading…
Reference in New Issue
Block a user