loss function gradients
This commit is contained in:
parent
60c3090981
commit
e4e9794f5e
@ -1,6 +1,6 @@
|
||||
# Cost functions
|
||||
|
||||
mse(ŷ, y) = sumabs2(ŷ .- y)/2
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2)/2
|
||||
# back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y)
|
||||
|
||||
logloss(ŷ, y) = -sum(y .* log.(ŷ))
|
||||
|
@ -41,9 +41,12 @@ istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.x
|
||||
grad(x::TrackedArray) = x.Δ
|
||||
|
||||
tovec(xs::AbstractArray) = vec(xs)
|
||||
tovec(xs) = xs
|
||||
|
||||
function back!(x::TrackedArray, Δ)
|
||||
Δ′ = vec(x.Δ)
|
||||
Δ′ .+= vec(Δ)
|
||||
Δ′ .+= tovec(Δ)
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
|
@ -14,10 +14,13 @@ end
|
||||
|
||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||
|
||||
back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
|
||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
||||
|
||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
||||
|
||||
@ -76,6 +79,7 @@ function back!(b::Broadcasted, Δ, args...)
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray
|
||||
Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray
|
||||
|
@ -14,4 +14,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
gradtest(x -> softmax(x).*(1:3), 3)
|
||||
gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
|
||||
gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
gradtest(Flux.logloss, rand(5,5), rand(5, 5))
|
||||
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user