a bunch of stuff
This commit is contained in:
parent
4249e6e961
commit
c92cff5dce
@ -1,6 +1,4 @@
|
|||||||
using Flux, MNIST, Flow, MacroTools
|
using Flux, MNIST
|
||||||
import Flux.MX: mxnet
|
|
||||||
import Flux: back!, update!, graph
|
|
||||||
|
|
||||||
@time begin
|
@time begin
|
||||||
const data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
const data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||||
@ -11,9 +9,9 @@ end
|
|||||||
|
|
||||||
m = Chain(
|
m = Chain(
|
||||||
Input(784),
|
Input(784),
|
||||||
Dense(784, 128), relu,
|
Dense(128), relu,
|
||||||
Dense(128, 64), relu,
|
Dense( 64), relu,
|
||||||
Dense(64, 10), softmax)
|
Dense( 10), softmax)
|
||||||
|
|
||||||
model = mxnet(m, 784)
|
model = mxnet(m, 784)
|
||||||
|
|
||||||
|
@ -10,12 +10,13 @@ include("utils.jl")
|
|||||||
include("compiler/diff.jl")
|
include("compiler/diff.jl")
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
|
|
||||||
|
include("layers/dense.jl")
|
||||||
|
include("layers/shape.jl")
|
||||||
|
include("layers/chain.jl")
|
||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
include("activation.jl")
|
include("activation.jl")
|
||||||
include("layers/input.jl")
|
|
||||||
include("layers/dense.jl")
|
|
||||||
include("layers/sequence.jl")
|
|
||||||
|
|
||||||
include("backend/mxnet/mxnet.jl")
|
include("backend/backend.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
@ -8,4 +8,4 @@ relu(x) = max(0, x)
|
|||||||
|
|
||||||
back(::typeof(relu), Δ, x) = Δ .* (x .< 0)
|
# Backward pass for `relu`: let the incoming gradient `Δ` through where the input
# `x` was positive and zero it elsewhere. relu(x) = max(0, x) has slope 1 for
# x > 0 and slope 0 for x < 0, so the mask must select the positive entries.
back(::typeof(relu), Δ, x) = Δ .* (x .> 0)
|
||||||
|
|
||||||
softmax(x) = error("not implemented")
|
# Exponentiate `xs` and normalise so that the result sums to one.
# The maximum is subtracted before exponentiating: the quotient is mathematically
# unchanged, but `exp` can no longer overflow to Inf (which made the result NaN)
# on large inputs. The exponentials are also computed once rather than twice.
function softmax(xs)
    # Nothing to normalise; also keeps `maximum` from throwing on empty input.
    isempty(xs) && return exp.(xs)
    es = exp.(xs .- maximum(xs))
    return es ./ sum(es)
end
|
||||||
|
7
src/backend/backend.jl
Normal file
7
src/backend/backend.jl
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# TODO: load backends lazily
|
||||||
|
|
||||||
|
include("mxnet/mxnet.jl")
|
||||||
|
|
||||||
|
using .MX
|
||||||
|
|
||||||
|
export mxnet
|
@ -32,3 +32,6 @@ node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
|||||||
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||||
node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
|
node(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
|
||||||
node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
node(::typeof(softmax), xs) = mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||||
|
|
||||||
|
graph(vars, ::Input, x) = x
|
||||||
|
graph(vars, ::Input, x, args...) = error("too many arguments to Input")
|
||||||
|
@ -2,6 +2,8 @@ module MX
|
|||||||
|
|
||||||
using MXNet, Flow, ..Flux
|
using MXNet, Flow, ..Flux
|
||||||
|
|
||||||
|
export mxnet
|
||||||
|
|
||||||
include("graph.jl")
|
include("graph.jl")
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
|
|
||||||
|
30
src/layers/chain.jl
Normal file
30
src/layers/chain.jl
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
export Chain
|
||||||
|
|
||||||
|
function inferchain(ms)
|
||||||
|
chain = []
|
||||||
|
sh = nothing
|
||||||
|
for m in ms
|
||||||
|
m = init(m, single(sh))
|
||||||
|
sh = shape(m, sh)
|
||||||
|
push!(chain, m)
|
||||||
|
end
|
||||||
|
return chain, sh
|
||||||
|
end
|
||||||
|
|
||||||
|
type Chain <: Model
|
||||||
|
layers::Vector{Any}
|
||||||
|
shape
|
||||||
|
function Chain(ms...)
|
||||||
|
ms, shape = inferchain(ms)
|
||||||
|
return new(ms, shape)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@forward Chain.layers Base.getindex, Base.first, Base.last
|
||||||
|
|
||||||
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
|
back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
|
||||||
|
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
||||||
|
|
||||||
|
graph(s::Chain) =
|
||||||
|
foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers)
|
@ -1,5 +1,7 @@
|
|||||||
export Dense
|
export Dense
|
||||||
|
|
||||||
|
# TODO: type hints for parameters
|
||||||
|
|
||||||
@model type Dense
|
@model type Dense
|
||||||
W
|
W
|
||||||
b
|
b
|
||||||
@ -9,7 +11,8 @@ end
|
|||||||
Dense(in::Integer, out::Integer; init = initn) =
|
Dense(in::Integer, out::Integer; init = initn) =
|
||||||
Dense(init(out, in), init(out))
|
Dense(init(out, in), init(out))
|
||||||
|
|
||||||
Base.show(io::IO, ::Dense) = print(io, "Flux.Dense(...)")
|
Base.show(io::IO, d::Dense) =
|
||||||
|
print(io, "Flux.Dense($(size(d.W.x,2)),$(size(d.W.x,1)))")
|
||||||
|
|
||||||
@model type Sigmoid
|
@model type Sigmoid
|
||||||
layer::Model
|
layer::Model
|
||||||
@ -18,3 +21,13 @@ end
|
|||||||
|
|
||||||
Sigmoid(in::Integer, out::Integer; init = randn) =
|
Sigmoid(in::Integer, out::Integer; init = randn) =
|
||||||
Sigmoid(Dense(in, out, init = init))
|
Sigmoid(Dense(in, out, init = init))
|
||||||
|
|
||||||
|
# @model type Recurrent
|
||||||
|
# Wxh; Whh; Bh
|
||||||
|
# Wxy; Why; By
|
||||||
|
#
|
||||||
|
# function (x)
|
||||||
|
# hidden = σ( Wxh*x + Whh*hidden + Bh )
|
||||||
|
# y = σ( Wxy*x + Why*hidden + By )
|
||||||
|
# end
|
||||||
|
# end
|
||||||
|
@ -1,26 +0,0 @@
|
|||||||
export Input
|
|
||||||
|
|
||||||
typealias Dims{N} NTuple{N,Int}
|
|
||||||
|
|
||||||
dims(d::Dims) = d
|
|
||||||
|
|
||||||
dims(i...) = (i...,)
|
|
||||||
|
|
||||||
type Input{N} <: Model
|
|
||||||
dims::Dims{N}
|
|
||||||
end
|
|
||||||
|
|
||||||
Input(i...) = Input(dims(i...))
|
|
||||||
|
|
||||||
(::Input)(x) = x
|
|
||||||
back!(::Input, ∇, x) = ∇
|
|
||||||
|
|
||||||
shape(i::Input) = i.dims
|
|
||||||
|
|
||||||
# Initialise placeholder
|
|
||||||
|
|
||||||
type Init{F}
|
|
||||||
f::F
|
|
||||||
end
|
|
||||||
|
|
||||||
(f::Init)(args...) = f.f(args...)
|
|
@ -1,23 +0,0 @@
|
|||||||
export Sequence
|
|
||||||
|
|
||||||
type Sequence <: Model
|
|
||||||
layers::Vector{Model}
|
|
||||||
end
|
|
||||||
|
|
||||||
Sequence() = Sequence([])
|
|
||||||
|
|
||||||
@forward Sequence.layers Base.getindex, Base.first, Base.last
|
|
||||||
|
|
||||||
Base.push!(s::Sequence, m::Model) = push!(s.layers, m)
|
|
||||||
|
|
||||||
Base.push!(s::Sequence, f::Init) = push!(s, f(shape(last(s))))
|
|
||||||
|
|
||||||
function Sequence(ms...)
|
|
||||||
s = Sequence()
|
|
||||||
foreach(m -> push!(s, m), ms)
|
|
||||||
return s
|
|
||||||
end
|
|
||||||
|
|
||||||
(s::Sequence)(x) = foldl((x, m) -> m(x), x, s.layers)
|
|
||||||
back!(s::Sequence, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
|
|
||||||
update!(s::Sequence, η) = foreach(l -> update!(l, η), s.layers)
|
|
42
src/layers/shape.jl
Normal file
42
src/layers/shape.jl
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
export Input
|
||||||
|
|
||||||
|
# Shim for kicking off shape inference
|
||||||
|
|
||||||
|
typealias Dims{N} NTuple{N,Int}
|
||||||
|
|
||||||
|
dims(d::Dims) = d
|
||||||
|
|
||||||
|
dims(i...) = (i...,)
|
||||||
|
|
||||||
|
single(i) = i
|
||||||
|
# Unwrap a dims tuple that holds exactly one entry to that entry; tuples of any
# other length are handed back unchanged.
function single(i::Dims)
    length(i) == 1 || return i
    return first(i)
end
|
||||||
|
|
||||||
|
type Input{N} <: Model
|
||||||
|
dims::Dims{N}
|
||||||
|
end
|
||||||
|
|
||||||
|
Input(i...) = Input(dims(i...))
|
||||||
|
|
||||||
|
(::Input)(x) = x
|
||||||
|
back!(::Input, ∇, x) = ∇
|
||||||
|
|
||||||
|
# Initialise placeholder
|
||||||
|
|
||||||
|
type Init{F}
|
||||||
|
f::F
|
||||||
|
end
|
||||||
|
|
||||||
|
init(i::Init, input...) = i.f(input...)
|
||||||
|
init(m, input...) = m
|
||||||
|
|
||||||
|
# Shape inference API
|
||||||
|
|
||||||
|
shape(x, in) = in
|
||||||
|
|
||||||
|
shape(i::Input, _) = i.dims
|
||||||
|
|
||||||
|
# Implementation for bundled layers
|
||||||
|
|
||||||
|
shape(d::Dense, _) = length(state(d.b)) # TODO: could perhaps infer this
|
||||||
|
|
||||||
|
Dense(out::Integer) = Init(in::Integer -> Dense(in, out))
|
11
src/utils.jl
11
src/utils.jl
@ -5,19 +5,20 @@ const AArray = AbstractArray
|
|||||||
onehot(label, labels) = [i == label for i in labels]
|
# Build an array with one entry per element of `labels`, holding the result of
# comparing that element to `label` with `==`.
function onehot(label, labels)
    return collect(candidate == label for candidate in labels)
end
|
||||||
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
# Return the entry of `labels` at the position of the largest value in `pred`
# (the first such position when several entries tie). `labels` defaults to the
# positions 1:length(pred) themselves.
# `findmax` yields (value, index) in a single pass, replacing the separate
# `maximum` scan followed by a value search.
onecold(pred, labels = 1:length(pred)) = labels[last(findmax(pred))]
|
||||||
|
|
||||||
function train!(m::Model, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
|
|
||||||
|
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||||
i = 0
|
i = 0
|
||||||
∇ = zeros(length(train[1][2]))
|
∇ = zeros(length(train[1][2]))
|
||||||
for _ in 1:epoch
|
for _ in 1:epoch
|
||||||
for (x, y) in shuffle!(train)
|
for (x, y) in train
|
||||||
i += 1
|
i += 1
|
||||||
err = mse!(∇, m(x), y)
|
pred = m(x)
|
||||||
|
any(isnan, pred) && error("NaN")
|
||||||
|
err = mse!(∇, pred, y)
|
||||||
back!(m, ∇, x)
|
back!(m, ∇, x)
|
||||||
i % batch == 0 && update!(m, η/batch)
|
i % batch == 0 && (update!(m, η); @show accuracy(m, test))
|
||||||
end
|
end
|
||||||
@show accuracy(m, test)
|
|
||||||
end
|
end
|
||||||
return m
|
return m
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user