grad refactor

This commit is contained in:
Mike J Innes 2017-09-06 21:21:35 -04:00
parent 3ef72a9d7b
commit 7cfc42d166
3 changed files with 34 additions and 35 deletions

View File

@ -1,5 +1,7 @@
module Tracker
using Base: RefValue
export track, back!
data(x) = x
@ -14,13 +16,10 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(::Call{Void}, Δ) = nothing
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
f::Call
x::A
Δ::A
data::A
grad::RefValue{A}
end
TrackedScalar{T,A} = TrackedArray{T,0,A}
@ -28,36 +27,19 @@ TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, zeros(x))
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x)
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, RefValue(zeros(x)))
track(xs) = TrackedArray(xs)
track(xs) = TrackedArray(AbstractFloat.(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.x
grad(x::TrackedArray) = x.Δ
tovec(xs::AbstractArray) = vec(xs)
tovec(xs) = xs
function back!(x::TrackedArray, Δ)
x.Δ .+= Δ
back!(x.f, Δ)
end
back!(x::TrackedScalar) = back!(x, 1)
macro back!(x, Δ)
quote
x = $(esc(x))
istracked(x) && back!(x, $(esc(Δ)))
end
end
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad[]
# Fallthrough methods
@ -84,6 +66,7 @@ function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = tru
end
end
include("back.jl")
include("lib.jl")
include("numeric.jl")
@ -91,7 +74,7 @@ using Requires
@require CuArrays begin
import CuArrays: cu
cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.x), cu(xs.Δ))
cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs))))
end
end

16
src/tracker/back.jl Normal file
View File

@ -0,0 +1,16 @@
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(::Call{Void}, Δ) = nothing
function back!(x::TrackedArray, Δ)
isassigned(x.grad) && (x.grad[] .+= Δ)
back!(x.f, Δ)
end
back!(x::TrackedScalar) = back!(x, 1)
macro back!(x, Δ)
quote
x = $(esc(x))
istracked(x) && back!(x, $(esc(Δ)))
end
end

View File

@ -7,10 +7,10 @@ unarray(xs) = xs
unarray(xs::AbstractArray{T,0} where T) = xs[]
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.x)
Δ′ = zeros(xs.data)
Δ′[i...] = unarray(Δ)
@back!(xs, Δ′)
end
@ -49,13 +49,13 @@ end
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
Base.sum(xs::TrackedScalar, dim...) = xs
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.x, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.x, args...)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
# BLAS