Merge pull request #44 from ylxdzsw/train-naive

training julia models
Mike J Innes 2017-07-03 18:41:39 +01:00 committed by GitHub
commit 7e4801832b
5 changed files with 66 additions and 3 deletions

@@ -6,11 +6,11 @@ module FluxCore
 """
     back!(model, ΔY, X...) => ΔX
 
-Backpropagate the gradient `ΔY` through the model `m`, accumulating the
+Backpropagate the gradient `ΔY` through the model `model`, accumulating the
 gradients of any parameters. Returns the gradient of the input `X`. Gradients
 may be arrays or tuples of arrays (for multiple inputs/outputs).
 """
-back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
+back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
 
 """
     update!(model, η) => m
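
For reference, here is what implementing that contract can look like for a custom layer. A minimal sketch: `Scale` is hypothetical (not part of this PR), stores its parameter as a plain array with an explicit gradient buffer rather than Flux's param wrapper, and is written in the Julia 0.5-era style the rest of this diff uses.

# Hypothetical elementwise layer y = a .* x, illustrating the
# back!/update! contract documented above.
type Scale                   # `mutable struct` on Julia 0.6+
  a::Vector{Float64}         # parameter
  Δa::Vector{Float64}        # accumulated gradient of `a`
end

Scale(n::Integer) = Scale(randn(n), zeros(n))

(m::Scale)(x) = m.a .* x     # forward pass

function back!(m::Scale, Δ, x)
  m.Δa[:] += Δ .* x          # accumulate the parameter gradient
  Δ .* m.a                   # return the gradient of the input
end

function update!(m::Scale, η)
  m.a[:] -= η * m.Δa         # naive gradient-descent step
  m.Δa[:] = 0                # reset the accumulator
  m
end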

@@ -9,3 +9,16 @@ Affine(in::Integer, out::Integer; init = initn) =
 
 inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
   Affine(in[1][2], out)
+
+function back!(m::Affine, Δ, x)
+  W, b = m.W, m.b
+  W.Δx[:] = x' * Δ
+  b.Δx[:] = sum(Δ, 1)
+  Δ * W.x'
+end
+
+function update!(m::Affine, η)
+  update!(m.W, η)
+  update!(m.b, η)
+  m
+end
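
The math behind `back!` here: the affine forward pass is Y = X*W .+ b for a batch of row-vector inputs X, so with the upstream gradient Δ = ∂L/∂Y the chain rule gives ∂L/∂W = X'Δ (stored in W.Δx), ∂L/∂b = sum(Δ, 1), the per-column sum over the batch (stored in b.Δx), and ∂L/∂X = Δ*W' (the returned value). A standalone finite-difference check of the weight gradient; the sizes, loss, and tolerance below are arbitrary.

# Check ∂L/∂W == X'Δ for loss(Y) = sum(Y), where Δ is all ones.
srand(1)
X, W, b = rand(4, 10), rand(10, 3), rand(1, 3)
loss(W) = sum(X * W .+ b)
Δ  = ones(4, 3)       # ∂loss/∂Y for loss = sum(Y)
∇W = X' * Δ           # analytic gradient, as computed in back!

ε = 1e-6
for i in 1:length(W)
  Wp = copy(W); Wp[i] += ε
  fd = (loss(Wp) - loss(W)) / ε    # numerical derivative
  @assert abs(fd - ∇W[i]) < 1e-4
end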

@@ -7,9 +7,19 @@
 end
 
 @forward Chain.layers Base.start, Base.next, Base.done
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
+update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
+
+function back!(s::Chain, Δ, x)
+  crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
+    push!(crumbs, layer(crumbs[end]))
+  end
+  foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
+    x, layer = pack
+    back!(layer, Δ, x)
+  end
+end
 
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
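
How the new `back!` works: the `foldl` replays the forward pass over all but the last layer, so `crumbs[i]` holds the input that layer `i` originally saw; the `foldr` then walks the (input, layer) pairs from the back, and each `back!` call consumes Δ and returns the gradient with respect to its own input, which becomes the Δ of the layer before it. The same pattern with plain scalar functions, as a standalone sketch (`fs` and `dfs` are made up here; the `foldl(op, v0, itr)` argument order matches the 0.6-era code above):

fs  = [x -> 2x, x -> x + 1, x -> x^2]                # composite: x ↦ (2x + 1)^2
dfs = [(Δ, x) -> 2Δ, (Δ, x) -> Δ, (Δ, x) -> 2x * Δ]  # pullback of each "layer"

x = 3.0
crumbs = foldl([x], fs[1:end-1]) do crumbs, f
  push!(crumbs, f(crumbs[end]))      # record the input each layer sees
end                                  # crumbs == [3.0, 6.0, 7.0]

Δ = foldr(1.0, collect(zip(crumbs, dfs))) do pack, Δ
  x, df = pack
  df(Δ, x)                           # becomes Δ for the layer one step earlier
end
# Δ == 28.0 == d/dx (2x + 1)^2 at x = 3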

test/optimizer.jl (new file, +38 lines)

@@ -0,0 +1,38 @@
+@testset "training julia models" begin
+
+@testset "linear regression" begin
+  srand(0)
+
+  model = Affine(10, 1)
+
+  truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
+
+  data = map(1:256) do i
+    x = rand(Float32, 10)
+    x, truth * x + 3rand(Float32)
+  end
+
+  Flux.train!(model, data, epoch=5)
+
+  @test cor(reshape.((model.W.x, truth), 10)...) > .99
+end
+
+@testset "logistic regression" begin
+  srand(0)
+
+  model = Chain(Affine(10, 1), σ)
+
+  truth = Float32[0, 4, 2, 2, -3, 6, -1, 3, 2, -5]'
+
+  data = map(1:256) do i
+    x = rand(Float32, 10)
+    x, truth * x + 2rand(Float32) > 5f0
+  end
+
+  Flux.train!(model, data, epoch=10)
+
+  @test cor(reshape.((model.layers[1].W.x, truth), 10)...) > .99
+end
+
+end
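
These tests drive `Flux.train!` end to end on a handwritten `Affine` layer and a `Chain`. A naive trainer consistent with the `back!`/`update!` methods above would loop over the data for the requested number of epochs, run the forward pass, seed Δ from the loss, and take a gradient step. A minimal sketch, assuming a squared-error loss (`naive_train!` is illustrative, not the implementation from this PR):

function naive_train!(model, data; epoch = 1, η = 0.1)
  for e in 1:epoch, (x, y) in data
    ŷ = model(x)          # forward pass
    Δ = ŷ .- y            # ∂(½‖ŷ − y‖²)/∂ŷ seeds the backward pass
    back!(model, Δ, x)    # accumulate parameter gradients
    update!(model, η)     # naive gradient-descent step
  end
  model
end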

@@ -15,5 +15,7 @@ include("backend/common.jl")
 
 include("basic.jl")
 include("recurrent.jl")
+include("optimizer.jl")
+
 @tfonly include("backend/tensorflow.jl")
 @mxonly include("backend/mxnet.jl")