consistently use delta for gradients
This commit is contained in:
parent
a330b394bd
commit
62fd13bded
@ -1,8 +1,8 @@
|
|||||||
export mse, mse!
|
export mse, mse!
|
||||||
|
|
||||||
function mse!(∇, pred, target)
|
function mse!(Δ, pred, target)
|
||||||
map!(-, ∇, pred, target)
|
map!(-, Δ, pred, target)
|
||||||
sumabs2(∇)/2
|
sumabs2(Δ)/2
|
||||||
end
|
end
|
||||||
|
|
||||||
mse(pred, target) = mse(similar(pred), pred, target)
|
mse(pred, target) = mse(similar(pred), pred, target)
|
||||||
|
@ -23,7 +23,7 @@ end
|
|||||||
@forward Chain.layers Base.getindex, Base.first, Base.last
|
@forward Chain.layers Base.getindex, Base.first, Base.last
|
||||||
|
|
||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
|
back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
|
||||||
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
||||||
|
|
||||||
graph(s::Chain) =
|
graph(s::Chain) =
|
||||||
|
@ -23,7 +23,7 @@ end
|
|||||||
Input(i...) = Input(dims(i...))
|
Input(i...) = Input(dims(i...))
|
||||||
|
|
||||||
(::Input)(x) = x
|
(::Input)(x) = x
|
||||||
back!(::Input, ∇, x) = ∇
|
back!(::Input, Δ, x) = Δ
|
||||||
|
|
||||||
# Initialise placeholder
|
# Initialise placeholder
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ export Model, back!, update!, param
|
|||||||
|
|
||||||
abstract Model
|
abstract Model
|
||||||
|
|
||||||
back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
back!(m::Model, Δ) = error("Backprop not implemented for $(typeof(m))")
|
||||||
update!(m, η) = m
|
update!(m, η) = m
|
||||||
|
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
@ -6,14 +6,14 @@ initn(dims...) = randn(Float32, dims...)/10
|
|||||||
|
|
||||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||||
i = 0
|
i = 0
|
||||||
∇ = zeros(length(train[1][2]))
|
Δ = zeros(length(train[1][2]))
|
||||||
for _ in 1:epoch
|
for _ in 1:epoch
|
||||||
@progress for (x, y) in train
|
@progress for (x, y) in train
|
||||||
i += 1
|
i += 1
|
||||||
pred = m(x)
|
pred = m(x)
|
||||||
any(isnan, pred) && error("NaN")
|
any(isnan, pred) && error("NaN")
|
||||||
err = mse!(∇, pred, y)
|
err = mse!(Δ, pred, y)
|
||||||
back!(m, ∇, x)
|
back!(m, Δ, x)
|
||||||
i % batch == 0 && update!(m, η)
|
i % batch == 0 && update!(m, η)
|
||||||
i % 1000 == 0 && @show accuracy(m, test)
|
i % 1000 == 0 && @show accuracy(m, test)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user