consistently use delta for gradients
This commit is contained in:
parent
a330b394bd
commit
62fd13bded
@ -1,8 +1,8 @@
|
||||
export mse, mse!
|
||||
|
||||
function mse!(∇, pred, target)
|
||||
map!(-, ∇, pred, target)
|
||||
sumabs2(∇)/2
|
||||
"""
    mse!(Δ, pred, target)

Write the elementwise difference `pred - target` into `Δ` and return half the
sum of the squared entries of `Δ`.

`Δ` is overwritten in place.
"""
function mse!(Δ, pred, target)
    map!(-, Δ, pred, target)
    # `sum(abs2, Δ)` is equivalent to `sumabs2(Δ)`, which was deprecated and
    # later removed from Base.
    return sum(abs2, Δ) / 2
end
|
||||
|
||||
# Non-mutating variant of `mse!`: allocates a scratch array with `similar(pred)`
# to receive the differences and returns whatever `mse!` returns.
# Fix: this previously forwarded to a three-argument `mse`, which is not the
# method defined for three arguments; that method is `mse!`.
mse(pred, target) = mse!(similar(pred), pred, target)
|
||||
|
@ -23,7 +23,7 @@ end
|
||||
# Forward `getindex`, `first` and `last` on a Chain to its `layers` field.
# NOTE(review): assumes `@forward` is the usual delegation macro (e.g. from Lazy.jl) — confirm.
@forward Chain.layers Base.getindex, Base.first, Base.last
|
||||
|
||||
# Forward pass: feed `x` through each layer in order, the output of one layer
# becoming the input of the next.
function (s::Chain)(x)
    return foldl((acc, layer) -> layer(acc), x, s.layers)
end
|
||||
# Backward pass through a Chain: `foldr` calls `back!` on the layers from last
# to first, each call's result becoming the ∇ passed to the preceding layer.
back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
|
||||
# Backward pass through a Chain: the layers are visited from last to first and
# each layer's `back!` result becomes the Δ handed to the layer before it.
function back!(s::Chain, Δ)
    return foldr((layer, grad) -> back!(layer, grad), Δ, s.layers)
end
|
||||
# Call `update!` with η on every layer of the chain; returns `nothing`.
function update!(s::Chain, η)
    for layer in s.layers
        update!(layer, η)
    end
    return nothing
end
|
||||
|
||||
graph(s::Chain) =
|
||||
|
@ -23,7 +23,7 @@ end
|
||||
# Convenience constructor: passes the arguments through `dims` and builds the
# Input from the result.
function Input(i...)
    return Input(dims(i...))
end
|
||||
|
||||
# Calling an Input returns its argument unchanged.
function (::Input)(x)
    return x
end
|
||||
# Backward pass for Input: returns the incoming gradient ∇ unchanged; `x` is unused.
back!(::Input, ∇, x) = ∇
|
||||
# Backward pass for Input: the incoming gradient Δ is returned as-is; `x` is
# accepted but not used.
function back!(::Input, Δ, x)
    return Δ
end
|
||||
|
||||
# Initialise placeholder
|
||||
|
||||
|
@ -4,7 +4,7 @@ export Model, back!, update!, param
|
||||
|
||||
# Abstract supertype for models; `back!` dispatches on it.
abstract Model
|
||||
|
||||
# Fallback `back!` for any Model without a more specific method: raises an
# error that names the model's concrete type.
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
||||
# Fallback `back!` for a Model with no more specific method: always raises an
# error naming the concrete type of `m`.
function back!(m::Model, Δ)
    error("Backprop not implemented for $(typeof(m))")
end
|
||||
# Default `update!`: does nothing and returns `m` itself.
function update!(m, η)
    return m
end
|
||||
|
||||
# Default `graph`: returns `nothing` for any argument.
function graph(m)
    return nothing
end
|
||||
|
@ -6,14 +6,14 @@ initn(dims...) = randn(Float32, dims...)/10
|
||||
|
||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||
i = 0
|
||||
∇ = zeros(length(train[1][2]))
|
||||
Δ = zeros(length(train[1][2]))
|
||||
for _ in 1:epoch
|
||||
@progress for (x, y) in train
|
||||
i += 1
|
||||
pred = m(x)
|
||||
any(isnan, pred) && error("NaN")
|
||||
err = mse!(∇, pred, y)
|
||||
back!(m, ∇, x)
|
||||
err = mse!(Δ, pred, y)
|
||||
back!(m, Δ, x)
|
||||
i % batch == 0 && update!(m, η)
|
||||
i % 1000 == 0 && @show accuracy(m, test)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user