From e4e9794f5e49872e087e238e32ef1a86c0501bfa Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 23 Aug 2017 17:50:43 +0100 Subject: [PATCH] loss function gradients --- src/layers/stateless.jl | 2 +- src/tracker/Tracker.jl | 5 ++++- src/tracker/lib.jl | 4 ++++ test/tracker.jl | 3 +++ 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index f53de2dd..52db7c7c 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -1,6 +1,6 @@ # Cost functions -mse(ŷ, y) = sumabs2(ŷ .- y)/2 +mse(ŷ, y) = sum((ŷ .- y).^2)/2 # back!(::typeof(mse), Δ, ŷ, y) = Δ .* (ŷ .- y) logloss(ŷ, y) = -sum(y .* log.(ŷ)) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index f1fcf37e..38374e3d 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -41,9 +41,12 @@ istracked(x::TrackedArray) = true data(x::TrackedArray) = x.x grad(x::TrackedArray) = x.Δ +tovec(xs::AbstractArray) = vec(xs) +tovec(xs) = xs + function back!(x::TrackedArray, Δ) Δ′ = vec(x.Δ) - Δ′ .+= vec(Δ) + Δ′ .+= tovec(Δ) back!(x.f, Δ) end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 28929d73..76f420d3 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -14,10 +14,13 @@ end Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) +back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ) + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x))) +Base.sum(xs::TrackedScalar, dim...) = xs back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ) @@ -76,6 +79,7 @@ function back!(b::Broadcasted, Δ, args...) end Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray +Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{TrackedArray}) = TrackedArray Base.Broadcast.promote_containertype(::Type{Array}, ::Type{TrackedArray}) = TrackedArray Base.Broadcast.promote_containertype(::Type{TrackedArray}, ::Type{Array}) = TrackedArray Base.Broadcast.promote_containertype(::Type{TrackedArray}, ct) = TrackedArray diff --git a/test/tracker.jl b/test/tracker.jl index 3ba4a93e..1e64b23a 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -14,4 +14,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) gradtest(x -> softmax(x).*(1:3), 3) gradtest(x -> softmax(x).*(1:3), (3,5)) +gradtest(Flux.mse, rand(5,5), rand(5, 5)) +gradtest(Flux.logloss, rand(5,5), rand(5, 5)) + end