commit
5caeeccb5f
|
@ -61,12 +61,6 @@ macro grad(ex)
|
|||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||
end
|
||||
|
||||
function update!(x, Δ)
|
||||
x.data .+= data(Δ)
|
||||
tracker(x).grad .= 0
|
||||
return x
|
||||
end
|
||||
|
||||
include("idset.jl")
|
||||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
|
|
|
@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =
|
|||
|
||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
||||
|
||||
function update!(x::TrackedArray, Δ)
|
||||
x.data .+= data(Δ)
|
||||
tracker(x).grad .= 0
|
||||
return x
|
||||
end
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
struct TrackedReal{T<:Real} <: Real
|
||||
mutable struct TrackedReal{T<:Real} <: Real
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
|
|||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function update!(x::TrackedReal, Δ)
|
||||
x.data += data(Δ)
|
||||
tracker(x).grad = 0
|
||||
return x
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
T = get(io, :typeinfo, Any)
|
||||
show(io, data(x))
|
||||
|
|
|
@ -286,4 +286,13 @@ end
|
|||
@test count == 3
|
||||
end
|
||||
|
||||
@testset "Updates" begin
|
||||
xs = param([1, 2, 3])
|
||||
Tracker.update!(xs, param([4, 5, 6]))
|
||||
@test xs == [5, 7, 9]
|
||||
x = param(3)
|
||||
Tracker.update!(x, param(4))
|
||||
@test x == 7
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue