loss function gradients

This commit is contained in:
Mike J Innes 2017-08-23 17:50:43 +01:00
parent 60c3090981
commit e4e9794f5e
4 changed files with 12 additions and 2 deletions

View File

@ -1,6 +1,6 @@
# Cost functions
mse(, y) = sumabs2( .- y)/2
mse(, y) = sum(( .- y).^2)/2
# back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y)
logloss(, y) = -sum(y .* log.())

View File

@ -41,9 +41,12 @@ istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.x
grad(x::TrackedArray) = x.Δ
tovec(xs::AbstractArray) = vec(xs)
tovec(xs) = xs
function back!(x::TrackedArray, Δ)
Δ′ = vec(x.Δ)
Δ′ .+= vec(Δ)
Δ′ .+= tovec(Δ)
back!(x.f, Δ)
end

View File

@ -14,10 +14,13 @@ end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
Base.sum(xs::TrackedScalar, dim...) = xs
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
@ -76,6 +79,7 @@ function back!(b::Broadcasted, Δ, args...)
end
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray

View File

@ -14,4 +14,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
gradtest(x -> softmax(x).*(1:3), 3)
gradtest(x -> softmax(x).*(1:3), (3,5))
gradtest(Flux.mse, rand(5,5), rand(5, 5))
gradtest(Flux.logloss, rand(5,5), rand(5, 5))
end