shave some memory

This commit is contained in:
Mike J Innes 2018-07-09 19:44:14 +01:00
parent 1430053b69
commit e763c342ee
4 changed files with 33 additions and 35 deletions

View File

@ -11,9 +11,9 @@ tracker(x) = nothing
istracked(x) = tracker(x) nothing istracked(x) = tracker(x) nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x)) isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x)) grad(x) = grad(tracker(x))
grad(::Void) = nothing grad(::Void) = nothing
data(x) = x
struct Call{F,As<:Tuple} struct Call{F,As<:Tuple}
func::F func::F
@ -32,29 +32,23 @@ mutable struct Tracked{T}
ref::UInt32 ref::UInt32
f::Call f::Call
isleaf::Bool isleaf::Bool
data::T
grad::T grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
end end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
istracked(x::Tracked) = true istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call() isleaf(x::Tracked) = x.f == Call()
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked(f, x) track(f::Call, x) = Tracked{typeof(x)}(f)
track(f::Call) = track(f, f())
function _forward end function _forward end
function track(f, xs...; kw...) function track(f, xs...; kw...)
y, back = _forward(f, data.(xs)...; kw...) y, back = _forward(f, data.(xs)...; kw...)
track(Call(back, xs), y) track(Call(back, tracker.(xs)), y)
end end
macro grad(ex) macro grad(ex)

View File

@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end end
data(x::TrackedArray) = x.data
tracker(x::TrackedArray) = x.tracker tracker(x::TrackedArray) = x.tracker
TrackedVector{T,A} = TrackedArray{T,1,A} TrackedVector{T,A} = TrackedArray{T,1,A}
@ -15,10 +16,10 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x) track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray = TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
@ -369,7 +370,7 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
end end
# So we can return non-tracked arrays # So we can return non-tracked arrays
track(Call(back, args), y) track(Call(back, tracker.(args)), y)
end end
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray

View File

@ -10,8 +10,6 @@ function scan(x::Tracked)
if ref == 1 if ref == 1
scan(x.f) scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
else
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end end
return return
end end
@ -25,7 +23,7 @@ function back_(c::Call, Δ)
Δs = c.func(Δ) Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) || (Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))") error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) foreach(back, c.args, Δs)
end end
back_(::Call{Void}, Δ) = nothing back_(::Call{Void}, Δ) = nothing
@ -36,8 +34,12 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ) function back(x::Tracked, Δ)
x.isleaf && (x.grad = accum!(x.grad, Δ); return) x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1 ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad) if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ) x.grad = accum!(x.grad, Δ)
else
x.grad = Δ
end
ref == 0 && back_(x.f, x.grad) ref == 0 && back_(x.f, x.grad)
else else
ref == 0 && back_(x.f, Δ) ref == 0 && back_(x.f, Δ)
@ -45,8 +47,7 @@ function back(x::Tracked, Δ)
return return
end end
back(x, Δ) = back(tracker(x), Δ) back(::Void, _) = return
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
# Interface methods # Interface methods
@ -55,13 +56,13 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
# Refcounts are also probably not safe in some situations (e.g. back called # Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator) # from within a backpropagator)
function back!(x::Tracked, Δ) function back!(x, Δ)
istracked(x) || return
scan(x) scan(x)
back(x, Δ) back(tracker(x), Δ)
return
end end
back!(x, Δ) = back!(tracker(x), Δ)
# Out-of-place gradients # Out-of-place gradients
struct Params struct Params
@ -91,7 +92,7 @@ function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ) Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) || (Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))") error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs) foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end end
back_(g::Grads, ::Call{Void}, Δ) = nothing back_(g::Grads, ::Call{Void}, Δ) = nothing
@ -108,8 +109,7 @@ function back(g::Grads, x::Tracked, Δ)
return return
end end
back(g::Grads, x, Δ) = back(g, tracker(x), Δ) back(::Grads, ::Void, _) = return
back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`")
function forward(f, ps::Params) function forward(f, ps::Params)
y = f() y = f()
@ -117,7 +117,7 @@ function forward(f, ps::Params)
g = Grads() g = Grads()
if istracked(y) if istracked(y)
scan(y) scan(y)
back(g, y, Δ) back(g, tracker(y), Δ)
end end
for p in ps for p in ps
haskey(g, tracker(p)) || haskey(g, tracker(p)) ||

View File

@ -1,12 +1,14 @@
struct TrackedReal{T<:Real} <: Real struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T} tracker::Tracked{T}
end end
TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x))) TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal) function back!(x::TrackedReal)
isinf(x) && error("Loss is Inf") isinf(x) && error("Loss is Inf")
@ -73,6 +75,7 @@ import Base:^
# Tuples # Tuples
struct TrackedTuple{T<:Tuple} struct TrackedTuple{T<:Tuple}
data::T
tracker::Tracked{T} tracker::Tracked{T}
end end
@ -82,7 +85,7 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x) init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x) zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f))
function Base.show(io::IO, xs::TrackedTuple) function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs)) show(io, data(xs))
@ -100,7 +103,7 @@ back(::typeof(getindex), Δ, t, i) =
function collect(xs) function collect(xs)
xs = Base.collect(xs) xs = Base.collect(xs)
track(Call(collect, (xs,)), data.(xs)) track(Call(collect, (tracker.(xs),)), data.(xs))
end end
function scan(c::Call{typeof(collect)}) function scan(c::Call{typeof(collect)})
@ -108,5 +111,5 @@ function scan(c::Call{typeof(collect)})
end end
function back_(c::Call{typeof(collect)}, Δ) function back_(c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ) foreach(back, c.args[1], Δ)
end end