From 735b970c12b9e8c5cd4b8010c04e84f814794dd9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 10 Jan 2019 10:19:05 +0000 Subject: [PATCH] fix update for scalars --- src/tracker/Tracker.jl | 6 ------ src/tracker/lib/array.jl | 6 ++++++ src/tracker/lib/real.jl | 8 +++++++- test/tracker.jl | 9 +++++++++ 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 3f059926..010f9f4f 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -61,12 +61,6 @@ macro grad(ex) @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc end -function update!(x, Δ) - x.data .+= data(Δ) - tracker(x).grad .= 0 - return x -end - include("idset.jl") include("back.jl") include("numeric.jl") diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index a94323ca..08a40db7 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) = back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`") +function update!(x::TrackedArray, Δ) + x.data .+= data(Δ) + tracker(x).grad .= 0 + return x +end + # Fallthrough methods for f in :[Base.size, Base.ndims, Base.collect].args diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index 146706c7..6e7a44f2 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -1,4 +1,4 @@ -struct TrackedReal{T<:Real} <: Real +mutable struct TrackedReal{T<:Real} <: Real data::T tracker::Tracked{T} end @@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true) return back!(x, 1, once = once) end +function update!(x::TrackedReal, Δ) + x.data += data(Δ) + tracker(x).grad = 0 + return x +end + function Base.show(io::IO, x::TrackedReal) T = get(io, :typeinfo, Any) show(io, data(x)) diff --git a/test/tracker.jl b/test/tracker.jl index 51f4ad96..b4eab012 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -286,4 +286,13 @@ end @test count == 3 end +@testset "Updates" begin + xs = param([1, 2, 3]) + Tracker.update!(xs, param([4, 5, 6])) + @test xs == [5, 7, 9] + x = param(3) + Tracker.update!(x, param(4)) + @test x == 7 +end + end #testset