From 62fd13bded45982f6a47a219327987f3317e3868 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Thu, 15 Dec 2016 21:37:39 +0000
Subject: [PATCH] consistently use delta for gradients

---
 src/cost.jl         | 6 +++---
 src/layers/chain.jl | 2 +-
 src/layers/shape.jl | 2 +-
 src/model.jl        | 2 +-
 src/utils.jl        | 6 +++---
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/cost.jl b/src/cost.jl
index 62772894..34267202 100644
--- a/src/cost.jl
+++ b/src/cost.jl
@@ -1,8 +1,8 @@
 export mse, mse!
 
-function mse!(∇, pred, target)
-  map!(-, ∇, pred, target)
-  sumabs2(∇)/2
+function mse!(Δ, pred, target)
+  map!(-, Δ, pred, target)
+  sumabs2(Δ)/2
 end
 
 mse(pred, target) = mse!(similar(pred), pred, target)
diff --git a/src/layers/chain.jl b/src/layers/chain.jl
index cb47b545..b4bd6ced 100644
--- a/src/layers/chain.jl
+++ b/src/layers/chain.jl
@@ -23,7 +23,7 @@ end
 @forward Chain.layers Base.getindex, Base.first, Base.last
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
+back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
 update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
 
 graph(s::Chain) =
diff --git a/src/layers/shape.jl b/src/layers/shape.jl
index 41c1d97b..5fe47cd9 100644
--- a/src/layers/shape.jl
+++ b/src/layers/shape.jl
@@ -23,7 +23,7 @@ end
 Input(i...) = Input(dims(i...))
 
 (::Input)(x) = x
-back!(::Input, ∇, x) = ∇
+back!(::Input, Δ, x) = Δ
 
 # Initialise placeholder
 
diff --git a/src/model.jl b/src/model.jl
index cdbbd21f..75d9f906 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -4,7 +4,7 @@ export Model, back!, update!, param
 
 abstract Model
 
-back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
+back!(m::Model, Δ) = error("Backprop not implemented for $(typeof(m))")
 
 update!(m, η) = m
 graph(m) = nothing
diff --git a/src/utils.jl b/src/utils.jl
index d53bd5d5..c685fd4e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -6,14 +6,14 @@ initn(dims...) = randn(Float32, dims...)/10
 
 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
   i = 0
-  ∇ = zeros(length(train[1][2]))
+  Δ = zeros(length(train[1][2]))
   for _ in 1:epoch
     @progress for (x, y) in train
       i += 1
       pred = m(x)
       any(isnan, pred) && error("NaN")
-      err = mse!(∇, pred, y)
-      back!(m, ∇, x)
+      err = mse!(Δ, pred, y)
+      back!(m, Δ, x)
       i % batch == 0 && update!(m, η)
       i % 1000 == 0 && @show accuracy(m, test)
     end
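
Notes (not part of the patch): below is a minimal sketch of the Δ convention
this change settles on — back! takes the delta Δ flowing back from the layer
above and returns the delta for the layer below, while mse! writes
pred .- target into Δ in place. `Scale` is a hypothetical one-parameter
model invented purely for illustration; only `Model`, `back!`, `update!`,
and `mse!` come from the code touched here (Julia 0.5-era syntax, matching
`abstract Model` above), and the sketch assumes it runs in the same module
as those definitions.

# Hypothetical one-parameter model: pred = w .* x
type Scale <: Model
  w::Float64
  grad::Float64
end

(m::Scale)(x) = m.w .* x

# back! receives the output delta Δ, accumulates the parameter
# gradient, and returns the delta for the layer below.
function back!(m::Scale, Δ, x)
  m.grad += sum(Δ .* x)
  m.w .* Δ
end

# Gradient-descent step, then zero the accumulator.
function update!(m::Scale, η)
  m.w -= η * m.grad
  m.grad = 0.0
  m
end

m = Scale(0.5, 0.0)
x, y = [1.0, 2.0], [2.0, 4.0]
Δ = zeros(length(y))
err = mse!(Δ, m(x), y)  # Δ now holds pred .- target
back!(m, Δ, x)          # propagate Δ and accumulate m.grad
update!(m, 0.1)         # take a step with η = 0.1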