shave some memory
This commit is contained in:
parent
1430053b69
commit
e763c342ee
@ -11,9 +11,9 @@ tracker(x) = nothing
|
|||||||
|
|
||||||
istracked(x) = tracker(x) ≠ nothing
|
istracked(x) = tracker(x) ≠ nothing
|
||||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||||
data(x) = istracked(x) ? data(tracker(x)) : x
|
|
||||||
grad(x) = grad(tracker(x))
|
grad(x) = grad(tracker(x))
|
||||||
grad(::Void) = nothing
|
grad(::Void) = nothing
|
||||||
|
data(x) = x
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
func::F
|
func::F
|
||||||
@ -32,29 +32,23 @@ mutable struct Tracked{T}
|
|||||||
ref::UInt32
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
isleaf::Bool
|
isleaf::Bool
|
||||||
data::T
|
|
||||||
grad::T
|
grad::T
|
||||||
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data)
|
Tracked{T}(f::Call) where T = new(0, f, false)
|
||||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad)
|
Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
|
||||||
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad)
|
Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
|
|
||||||
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
|
|
||||||
|
|
||||||
istracked(x::Tracked) = true
|
istracked(x::Tracked) = true
|
||||||
isleaf(x::Tracked) = x.f == Call()
|
isleaf(x::Tracked) = x.f == Call()
|
||||||
data(x::Tracked) = x.data
|
|
||||||
grad(x::Tracked) = x.grad
|
grad(x::Tracked) = x.grad
|
||||||
|
|
||||||
track(f::Call, x) = Tracked(f, x)
|
track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||||
track(f::Call) = track(f, f())
|
|
||||||
|
|
||||||
function _forward end
|
function _forward end
|
||||||
|
|
||||||
function track(f, xs...; kw...)
|
function track(f, xs...; kw...)
|
||||||
y, back = _forward(f, data.(xs)...; kw...)
|
y, back = _forward(f, data.(xs)...; kw...)
|
||||||
track(Call(back, xs), y)
|
track(Call(back, tracker.(xs)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
macro grad(ex)
|
macro grad(ex)
|
||||||
|
@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
|||||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
data(x::TrackedArray) = x.data
|
||||||
tracker(x::TrackedArray) = x.tracker
|
tracker(x::TrackedArray) = x.tracker
|
||||||
|
|
||||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||||
@ -15,10 +16,10 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
|||||||
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||||
|
|
||||||
@ -369,7 +370,7 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
|
|||||||
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
|
map((x, Δ) -> unbroadcast(x, Δ), args, Δargs)
|
||||||
end
|
end
|
||||||
# So we can return non-tracked arrays
|
# So we can return non-tracked arrays
|
||||||
track(Call(back, args), y)
|
track(Call(back, tracker.(args)), y)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
|
||||||
|
@ -10,8 +10,6 @@ function scan(x::Tracked)
|
|||||||
if ref == 1
|
if ref == 1
|
||||||
scan(x.f)
|
scan(x.f)
|
||||||
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
|
||||||
else
|
|
||||||
isdefined(x, :grad) || (x.grad = init_grad(x.data))
|
|
||||||
end
|
end
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -25,7 +23,7 @@ function back_(c::Call, Δ)
|
|||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||||
error("Gradient is not a tuple of length $(length(c.args))")
|
error("Gradient is not a tuple of length $(length(c.args))")
|
||||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs)
|
foreach(back, c.args, Δs)
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(::Call{Void}, Δ) = nothing
|
back_(::Call{Void}, Δ) = nothing
|
||||||
@ -36,8 +34,12 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
|||||||
function back(x::Tracked, Δ)
|
function back(x::Tracked, Δ)
|
||||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
|
if ref > 0 || isdefined(x, :grad)
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad = accum!(x.grad, Δ)
|
x.grad = accum!(x.grad, Δ)
|
||||||
|
else
|
||||||
|
x.grad = Δ
|
||||||
|
end
|
||||||
ref == 0 && back_(x.f, x.grad)
|
ref == 0 && back_(x.f, x.grad)
|
||||||
else
|
else
|
||||||
ref == 0 && back_(x.f, Δ)
|
ref == 0 && back_(x.f, Δ)
|
||||||
@ -45,8 +47,7 @@ function back(x::Tracked, Δ)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(x, Δ) = back(tracker(x), Δ)
|
back(::Void, _) = return
|
||||||
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
|
||||||
|
|
||||||
# Interface methods
|
# Interface methods
|
||||||
|
|
||||||
@ -55,13 +56,13 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
|||||||
# Refcounts are also probably not safe in some situations (e.g. back called
|
# Refcounts are also probably not safe in some situations (e.g. back called
|
||||||
# from within a backpropagator)
|
# from within a backpropagator)
|
||||||
|
|
||||||
function back!(x::Tracked, Δ)
|
function back!(x, Δ)
|
||||||
|
istracked(x) || return
|
||||||
scan(x)
|
scan(x)
|
||||||
back(x, Δ)
|
back(tracker(x), Δ)
|
||||||
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back!(x, Δ) = back!(tracker(x), Δ)
|
|
||||||
|
|
||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
|
|
||||||
struct Params
|
struct Params
|
||||||
@ -91,7 +92,7 @@ function back_(g::Grads, c::Call, Δ)
|
|||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||||
error("Gradient is not a tuple of length $(length(c.args))")
|
error("Gradient is not a tuple of length $(length(c.args))")
|
||||||
foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs)
|
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||||
end
|
end
|
||||||
|
|
||||||
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
back_(g::Grads, ::Call{Void}, Δ) = nothing
|
||||||
@ -108,8 +109,7 @@ function back(g::Grads, x::Tracked, Δ)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
back(g::Grads, x, Δ) = back(g, tracker(x), Δ)
|
back(::Grads, ::Void, _) = return
|
||||||
back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`")
|
|
||||||
|
|
||||||
function forward(f, ps::Params)
|
function forward(f, ps::Params)
|
||||||
y = f()
|
y = f()
|
||||||
@ -117,7 +117,7 @@ function forward(f, ps::Params)
|
|||||||
g = Grads()
|
g = Grads()
|
||||||
if istracked(y)
|
if istracked(y)
|
||||||
scan(y)
|
scan(y)
|
||||||
back(g, y, Δ)
|
back(g, tracker(y), Δ)
|
||||||
end
|
end
|
||||||
for p in ps
|
for p in ps
|
||||||
haskey(g, tracker(p)) ||
|
haskey(g, tracker(p)) ||
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
struct TrackedReal{T<:Real} <: Real
|
struct TrackedReal{T<:Real} <: Real
|
||||||
|
data::T
|
||||||
tracker::Tracked{T}
|
tracker::Tracked{T}
|
||||||
end
|
end
|
||||||
|
|
||||||
TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x)))
|
TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||||
|
|
||||||
|
data(x::TrackedReal) = x.data
|
||||||
tracker(x::TrackedReal) = x.tracker
|
tracker(x::TrackedReal) = x.tracker
|
||||||
|
|
||||||
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||||
|
|
||||||
function back!(x::TrackedReal)
|
function back!(x::TrackedReal)
|
||||||
isinf(x) && error("Loss is Inf")
|
isinf(x) && error("Loss is Inf")
|
||||||
@ -73,6 +75,7 @@ import Base:^
|
|||||||
# Tuples
|
# Tuples
|
||||||
|
|
||||||
struct TrackedTuple{T<:Tuple}
|
struct TrackedTuple{T<:Tuple}
|
||||||
|
data::T
|
||||||
tracker::Tracked{T}
|
tracker::Tracked{T}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -82,7 +85,7 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
|||||||
init_grad(x::Tuple) = init_grad.(x)
|
init_grad(x::Tuple) = init_grad.(x)
|
||||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||||
|
|
||||||
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs))
|
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f))
|
||||||
|
|
||||||
function Base.show(io::IO, xs::TrackedTuple)
|
function Base.show(io::IO, xs::TrackedTuple)
|
||||||
show(io, data(xs))
|
show(io, data(xs))
|
||||||
@ -100,7 +103,7 @@ back(::typeof(getindex), Δ, t, i) =
|
|||||||
|
|
||||||
function collect(xs)
|
function collect(xs)
|
||||||
xs = Base.collect(xs)
|
xs = Base.collect(xs)
|
||||||
track(Call(collect, (xs,)), data.(xs))
|
track(Call(collect, (tracker.(xs),)), data.(xs))
|
||||||
end
|
end
|
||||||
|
|
||||||
function scan(c::Call{typeof(collect)})
|
function scan(c::Call{typeof(collect)})
|
||||||
@ -108,5 +111,5 @@ function scan(c::Call{typeof(collect)})
|
|||||||
end
|
end
|
||||||
|
|
||||||
function back_(c::Call{typeof(collect)}, Δ)
|
function back_(c::Call{typeof(collect)}, Δ)
|
||||||
foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ)
|
foreach(back, c.args[1], Δ)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user