diff --git a/src/Tracker/Tracker.jl b/src/Tracker/Tracker.jl index 24a1950c..3eaa27fb 100644 --- a/src/Tracker/Tracker.jl +++ b/src/Tracker/Tracker.jl @@ -15,40 +15,41 @@ back!(c::Call, Δ) = back!(c.func, Δ, c.args...) back!(f, Δ) = nothing -struct Var{T,N,A} <: AbstractArray{T,N} +struct TrackedArray{T,N,A} <: AbstractArray{T,N} f::Call x::A Δ::A end -ScalarVar{T,A} = Var{T,0,A} -VectorVar{T,A} = Var{T,1,A} -MatrixVar{T,A} = Var{T,2,A} +TrackedScalar{T,A} = TrackedArray{T,0,A} +TrackedVector{T,A} = TrackedArray{T,1,A} +TrackedMatrix{T,A} = TrackedArray{T,2,A} -Var(c::Call, x::A, Δ::A) where A <: AbstractArray = - Var{eltype(A),ndims(A),A}(c, x, Δ) +TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = + TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) -Var(c::Call, x::AbstractArray) = Var(c, x, zeros(x)) +TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x)) -Var(c::Call) = Var(c, c()) +TrackedArray(c::Call) = TrackedArray(c, c()) -Var(x::AbstractArray) = Var(Call(nothing), x) +TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x) -data(x::Var) = x.x -grad(x::Var) = x.Δ +track(xs) = TrackedArray(xs) +data(x::TrackedArray) = x.x +grad(x::TrackedArray) = x.Δ -function back!(x::Var, Δ) +function back!(x::TrackedArray, Δ) x.Δ .+= Δ back!(x.f, Δ) end for f in :[Base.size, Base.ndims, Base.similar].args - @eval @inline $f(x::Var, a...) = $f(data(x), a...) + @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) end -function Base.showarray(io::IO, X::Var, repr::Bool = true; header = true) +function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) if repr - print(io, "Var(") + print(io, "TrackedArray(") Base.showarray(io, data(X), true) print(io, ")") else diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index a1e35c67..00fbb407 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -1,7 +1,7 @@ import Base: * -a::MatrixVar * b::Union{MatrixVar,AbstractMatrix} = Var(Call(*, a, b)) -a::Union{MatrixVar,AbstractMatrix} * b::MatrixVar = Var(Call(*, a, b)) +a::TrackedMatrix * b::Union{TrackedMatrix,AbstractMatrix} = Var(Call(*, a, b)) +a::Union{TrackedMatrix,AbstractMatrix} * b::TrackedMatrix = Var(Call(*, a, b)) function back!(::typeof(*), Δ, a::AbstractArray, b::AbstractArray) back!(a, A_mul_Bt(Δ, data(b)))