fix update for scalars

This commit is contained in:
Mike J Innes 2019-01-10 10:19:05 +00:00
parent 9781f063aa
commit 735b970c12
4 changed files with 22 additions and 7 deletions

View File

@ -61,12 +61,6 @@ macro grad(ex)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
function update!(x, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
include("idset.jl")
include("back.jl")
include("numeric.jl")

View File

@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
function update!(x::TrackedArray, Δ)
x.data .+= data(Δ)
tracker(x).grad .= 0
return x
end
# Fallthrough methods
for f in :[Base.size, Base.ndims, Base.collect].args

View File

@ -1,4 +1,4 @@
struct TrackedReal{T<:Real} <: Real
mutable struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T}
end
@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
return back!(x, 1, once = once)
end
function update!(x::TrackedReal, Δ)
x.data += data(Δ)
tracker(x).grad = 0
return x
end
function Base.show(io::IO, x::TrackedReal)
T = get(io, :typeinfo, Any)
show(io, data(x))

View File

@ -286,4 +286,13 @@ end
@test count == 3
end
@testset "Updates" begin
xs = param([1, 2, 3])
Tracker.update!(xs, param([4, 5, 6]))
@test xs == [5, 7, 9]
x = param(3)
Tracker.update!(x, param(4))
@test x == 7
end
end #testset