commit
5caeeccb5f
|
@ -61,12 +61,6 @@ macro grad(ex)
|
||||||
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
|
||||||
end
|
end
|
||||||
|
|
||||||
function update!(x, Δ)
|
|
||||||
x.data .+= data(Δ)
|
|
||||||
tracker(x).grad .= 0
|
|
||||||
return x
|
|
||||||
end
|
|
||||||
|
|
||||||
include("idset.jl")
|
include("idset.jl")
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
|
@ -65,6 +65,12 @@ Base.setindex!(xs::TrackedArray, v, i...) =
|
||||||
|
|
||||||
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
back!(::TrackedArray) = error("Value is not scalar; use `back!(sum(x))` or `back!(x, Δ)`")
|
||||||
|
|
||||||
|
function update!(x::TrackedArray, Δ)
|
||||||
|
x.data .+= data(Δ)
|
||||||
|
tracker(x).grad .= 0
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims, Base.collect].args
|
for f in :[Base.size, Base.ndims, Base.collect].args
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
struct TrackedReal{T<:Real} <: Real
|
mutable struct TrackedReal{T<:Real} <: Real
|
||||||
data::T
|
data::T
|
||||||
tracker::Tracked{T}
|
tracker::Tracked{T}
|
||||||
end
|
end
|
||||||
|
@ -16,6 +16,12 @@ function back!(x::TrackedReal; once = true)
|
||||||
return back!(x, 1, once = once)
|
return back!(x, 1, once = once)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function update!(x::TrackedReal, Δ)
|
||||||
|
x.data += data(Δ)
|
||||||
|
tracker(x).grad = 0
|
||||||
|
return x
|
||||||
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedReal)
|
function Base.show(io::IO, x::TrackedReal)
|
||||||
T = get(io, :typeinfo, Any)
|
T = get(io, :typeinfo, Any)
|
||||||
show(io, data(x))
|
show(io, data(x))
|
||||||
|
|
|
@ -286,4 +286,13 @@ end
|
||||||
@test count == 3
|
@test count == 3
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Updates" begin
|
||||||
|
xs = param([1, 2, 3])
|
||||||
|
Tracker.update!(xs, param([4, 5, 6]))
|
||||||
|
@test xs == [5, 7, 9]
|
||||||
|
x = param(3)
|
||||||
|
Tracker.update!(x, param(4))
|
||||||
|
@test x == 7
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
|
Loading…
Reference in New Issue