shave some memory
This commit is contained in:
parent
1430053b69
commit
e763c342ee
@ -11,9 +11,9 @@ tracker(x) = nothing
|
||||
|
||||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
data(x) = istracked(x) ? data(tracker(x)) : x
|
||||
grad(x) = grad(tracker(x))
|
||||
grad(::Void) = nothing
|
||||
data(x) = x
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
@ -32,29 +32,23 @@ mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
isleaf::Bool
|
||||
data::T
|
||||
grad::T
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
||||
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
||||
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
|
||||
end
|
||||
|
||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
||||
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
||||
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call()
|
||||
data(x::Tracked) = x.data
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
track(f::Call, x) = Tracked(f, x)
|
||||
track(f::Call) = track(f, f())
|
||||
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
|
||||
function _forward end
|
||||
|
||||
function track(f, xs...; kw...)
|
||||
y, back = _forward(f, data.(xs)...; kw...)
|
||||
track(Call(back, xs), y)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
||||
macro grad(ex)
|
||||
|
@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
data(x::TrackedArray) = x.data
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
@ -15,10 +16,10 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||
|
||||
@ -369,7 +370,7 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, args), y)
|
||||
track(Call(back, tracker.(args)), y)
|
||||
end
|
||||
|
||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
||||
|
@ -10,8 +10,6 @@ function scan(x::Tracked)
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||
else
|
||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
||||
end
|
||||
return
|
||||
end
|
||||
@ -25,7 +23,7 @@ function back_(c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs)
|
||||
foreach(back, c.args, Δs)
|
||||
end
|
||||
|
||||
back_(::Call{Void}, Δ) = nothing
|
||||
@ -36,8 +34,12 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||
function back(x::Tracked, Δ)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
if ref > 0 || isdefined(x, :grad)
|
||||
if isdefined(x, :grad)
|
||||
x.grad = accum!(x.grad, Δ)
|
||||
else
|
||||
x.grad = Δ
|
||||
end
|
||||
ref == 0 && back_(x.f, x.grad)
|
||||
else
|
||||
ref == 0 && back_(x.f, Δ)
|
||||
@ -45,8 +47,7 @@ function back(x::Tracked, Δ)
|
||||
return
|
||||
end
|
||||
|
||||
back(x, Δ) = back(tracker(x), Δ)
|
||||
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
back(::Void, _) = return
|
||||
|
||||
# Interface methods
|
||||
|
||||
@ -55,13 +56,13 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||
# from within a backpropagator)
|
||||
|
||||
function back!(x::Tracked, Δ)
|
||||
function back!(x, Δ)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
back(x, Δ)
|
||||
back(tracker(x), Δ)
|
||||
return
|
||||
end
|
||||
|
||||
back!(x, Δ) = back!(tracker(x), Δ)
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
@ -91,7 +92,7 @@ function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs)
|
||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||
end
|
||||
|
||||
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
||||
@ -108,8 +109,7 @@ function back(g::Grads, x::Tracked, Δ)
|
||||
return
|
||||
end
|
||||
|
||||
back(g::Grads, x, Δ) = back(g, tracker(x), Δ)
|
||||
back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
||||
back(::Grads, ::Void, _) = return
|
||||
|
||||
function forward(f, ps::Params)
|
||||
y = f()
|
||||
@ -117,7 +117,7 @@ function forward(f, ps::Params)
|
||||
g = Grads()
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
back(g, y, Δ)
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
for p in ps
|
||||
haskey(g, tracker(p)) ||
|
||||
|
@ -1,12 +1,14 @@
|
||||
struct TrackedReal{T<:Real} <: Real
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x)))
|
||||
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||
|
||||
data(x::TrackedReal) = x.data
|
||||
tracker(x::TrackedReal) = x.tracker
|
||||
|
||||
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
@ -73,6 +75,7 @@ import Base:^
|
||||
# Tuples
|
||||
|
||||
struct TrackedTuple{T<:Tuple}
|
||||
data::T
|
||||
tracker::Tracked{T}
|
||||
end
|
||||
|
||||
@ -82,7 +85,7 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
||||
init_grad(x::Tuple) = init_grad.(x)
|
||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
||||
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
@ -100,7 +103,7 @@ back(::typeof(getindex), Δ, t, i) =
|
||||
|
||||
function collect(xs)
|
||||
xs = Base.collect(xs)
|
||||
track(Call(collect, (xs,)), data.(xs))
|
||||
track(Call(collect, (tracker.(xs),)), data.(xs))
|
||||
end
|
||||
|
||||
function scan(c::Call{typeof(collect)})
|
||||
@ -108,5 +111,5 @@ function scan(c::Call{typeof(collect)})
|
||||
end
|
||||
|
||||
function back_(c::Call{typeof(collect)}, Δ)
|
||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ)
|
||||
foreach(back, c.args[1], Δ)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user