grad refactor
This commit is contained in:
parent
3ef72a9d7b
commit
7cfc42d166
@ -1,5 +1,7 @@
|
||||
module Tracker
|
||||
|
||||
using Base: RefValue
|
||||
|
||||
export track, back!
|
||||
|
||||
data(x) = x
|
||||
@ -14,13 +16,10 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
||||
|
||||
(c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
back!(::Call{Void}, Δ) = nothing
|
||||
|
||||
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||
f::Call
|
||||
x::A
|
||||
Δ::A
|
||||
data::A
|
||||
grad::RefValue{A}
|
||||
end
|
||||
|
||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||
@ -28,36 +27,19 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
||||
|
||||
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x))
|
||||
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
|
||||
|
||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
|
||||
|
||||
track(xs) = TrackedArray(xs)
|
||||
track(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.x
|
||||
grad(x::TrackedArray) = x.Δ
|
||||
|
||||
tovec(xs::AbstractArray) = vec(xs)
|
||||
tovec(xs) = xs
|
||||
|
||||
function back!(x::TrackedArray, Δ)
|
||||
x.Δ .+= Δ
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
back!(x::TrackedScalar) = back!(x, 1)
|
||||
|
||||
macro back!(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
istracked(x) && back!(x, $(esc(Δ)))
|
||||
end
|
||||
end
|
||||
data(x::TrackedArray) = x.data
|
||||
grad(x::TrackedArray) = x.grad[]
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
@ -84,6 +66,7 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
|
||||
end
|
||||
end
|
||||
|
||||
include("back.jl")
|
||||
include("lib.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
@ -91,7 +74,7 @@ using Requires
|
||||
|
||||
@require CuArrays begin
|
||||
import CuArrays: cu
|
||||
cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.x), cu(xs.Δ))
|
||||
cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs))))
|
||||
end
|
||||
|
||||
end
|
||||
|
16
src/tracker/back.jl
Normal file
16
src/tracker/back.jl
Normal file
@ -0,0 +1,16 @@
|
||||
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
||||
back!(::Call{Void}, Δ) = nothing
|
||||
|
||||
function back!(x::TrackedArray, Δ)
|
||||
isassigned(x.grad) && (x.grad[] .+= Δ)
|
||||
back!(x.f, Δ)
|
||||
end
|
||||
|
||||
back!(x::TrackedScalar) = back!(x, 1)
|
||||
|
||||
macro back!(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
istracked(x) && back!(x, $(esc(Δ)))
|
||||
end
|
||||
end
|
@ -7,10 +7,10 @@ unarray(xs) = xs
|
||||
unarray(xs::AbstractArray{T,0} where T) = xs[]
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) =
|
||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
|
||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
|
||||
|
||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||
Δ′ = zeros(xs.x)
|
||||
Δ′ = zeros(xs.data)
|
||||
Δ′[i...] = unarray(Δ)
|
||||
@back!(xs, Δ′)
|
||||
end
|
||||
@ -49,13 +49,13 @@ end
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
||||
|
||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ)
|
||||
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.x, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.x, args...)
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
# BLAS
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user