2017-08-19 10:06:19 +00:00
|
|
|
|
module Tracker
|
|
|
|
|
|
2018-07-06 10:28:18 +00:00
|
|
|
|
using MacroTools
|
2018-07-09 15:57:44 +00:00
|
|
|
|
using MacroTools: @q, @forward
|
2018-07-06 10:28:18 +00:00
|
|
|
|
|
2018-03-05 23:44:25 +00:00
|
|
|
|
import Base: ==
|
|
|
|
|
|
2018-07-11 14:31:22 +00:00
|
|
|
|
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
|
2017-08-19 15:20:53 +00:00
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
|
# Generic fallbacks for plain (untracked) values. Tracked types presumably
# override `tracker`/`data` in the included files — TODO confirm.

# The `Tracked` node attached to `x`; plain values have none.
tracker(x) = nothing

# `x` is tracked exactly when it carries a tracker. `!==` (identity) is the
# idiomatic, inference-friendly comparison against `nothing`.
istracked(x) = tracker(x) !== nothing

# Untracked values are always leaves; tracked values defer to their tracker.
isleaf(x) = !istracked(x) || isleaf(tracker(x))

# Gradient of `x`, looked up through its tracker; `nothing` when untracked.
grad(x) = grad(tracker(x))
grad(::Nothing) = nothing

# The underlying value; a plain value is its own data.
data(x) = x
|
2017-08-19 09:14:50 +00:00
|
|
|
|
|
2017-08-18 15:50:27 +00:00
|
|
|
|
"""
    Call(func, args::Tuple)
    Call()

A recorded function application: `func` together with the tuple of `args` it
was applied to. `Call()` is the empty call (`func === nothing`, no arguments).
"""
struct Call{F,As<:Tuple}
    func::F
    args::As
end

# Explicit `where F` presumably forces specialisation on the function's type.
Call(f::F, args::T) where {F,T} = Call{F,T}(f, args)
Call() = Call(nothing, ())

# When deserialising, the object_id changes, so compare field-wise rather than
# relying on identity.
a::Call == b::Call = a.func == b.func && a.args == b.args

# `==` is overloaded above, so `hash` must agree with it: calls that compare
# equal field-wise must hash equally (required by `Set`, `Dict`, `isequal`).
Base.hash(c::Call, h::UInt) = hash(c.args, hash(c.func, hash(:Call, h)))

# Invoke the recorded function on the plain data of its arguments.
@inline (c::Call)() = c.func(data.(c.args)...)
|
2017-08-19 09:14:50 +00:00
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
|
# A node on the tape: the `Call` that produced a value, plus (optionally)
# storage of type `T` for that value's gradient.
mutable struct Tracked{T}
    # Always 0 at construction; presumably a reference count used during the
    # backward pass — TODO confirm.
    ref::UInt32
    # The recorded call. NOTE(review): `Call` is parametric, so this field type
    # is abstract; kept as-is to preserve the `Tracked{T}` signature.
    f::Call
    # `true` only when built from a `Call{Nothing}` together with a gradient.
    isleaf::Bool
    # Left uninitialised by the single-argument constructor.
    grad::T

    function Tracked{T}(f::Call) where T
        return new{T}(UInt32(0), f, false)
    end

    function Tracked{T}(f::Call, grad::T) where T
        return new{T}(UInt32(0), f, false, grad)
    end

    function Tracked{T}(f::Call{Nothing}, grad::T) where T
        return new{T}(UInt32(0), f, true, grad)
    end
end
|
2017-08-19 10:00:55 +00:00
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
|
# A `Tracked` node is tracked by definition.
istracked(::Tracked) = true

# A node counts as a leaf when its recorded call equals the empty call.
isleaf(node::Tracked) = node.f == Call()

# The gradient stored in the node's `grad` field.
grad(node::Tracked) = node.grad
|
2017-10-26 10:15:14 +00:00
|
|
|
|
|
2018-07-09 18:44:14 +00:00
|
|
|
|
"""
    track(f::Call, x)

Fallback: record `x` as the output of call `f` by building a
`Tracked{typeof(x)}` node (no gradient storage is initialised).
"""
track(f::Call, x) = Tracked{typeof(x)}(f)

"""
    _forward(f, args...)

Generic function holding forward rules; methods are added via `@grad`. Each
method returns a pair `(y, back)`: the forward result and a pullback that maps
the output gradient to a tuple of input gradients.
"""
function _forward end

"""
    track(f, xs...)

Run the forward rule for `f` on `xs`, then record the result as a node whose
`Call` pairs the pullback with the trackers of `xs`.
"""
function track(f::F, xs...) where F
    result, pullback = _forward(f, xs...)
    parents = map(tracker, xs)
    recorded = Call(pullback, parents)
    return track(recorded, result)
end
|
|
|
|
|
|
|
|
|
|
# Keyword-argument variant of `track`: the keywords are forwarded to the
# `_forward` rule; only the positional arguments' trackers are recorded.
function track_kw(f::F, xs...; kw...) where F
    result, pullback = _forward(f, xs...; kw...)
    recorded = Call(pullback, map(tracker, xs))
    return track(recorded, result)
end
|
|
|
|
|
|
|
|
|
|
"""
    @grad f(args...) = (y, back)
    @grad f(args...) where {T} = (y, back)
    @grad function f(args...) ... end

Define a forward rule for `f` by expanding to a method of `Tracker._forward`
whose first argument is `::typeof(f)` (or the caller's own `::T` annotation).
The body should return the forward value and a pullback closure.
"""
macro grad(ex)
    @capture(shortdef(ex), (name_(args__) = body_) |
             (name_(args__) where {T__} = body_)) || error("Need a function definition")
    # `@capture` leaves `T` as `nothing` when the `where` form did not match.
    T === nothing && (T = [])
    # Dispatch on the function itself unless the caller already wrote `::Type`.
    isexpr(name, :(::)) || (name = :(::typeof($name)))
    # Keyword `parameters` must stay first, so the function argument goes after
    # them. Guard the lookup: a zero-argument definition has an empty `args`.
    haskw = !isempty(args) && isexpr(args[1], :parameters)
    insert!(args, 1 + haskw, name)
    return @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
|
|
|
|
|
|
2018-06-29 12:53:50 +00:00
|
|
|
|
"""
    update!(x, Δ)

Add `data(Δ)` into `x.data` in place (broadcasting), then zero the gradient
array held by `x`'s tracker. Returns `x`. Assumes `x` has a mutable `data`
array and an array-valued tracker gradient — TODO confirm for scalars.
"""
function update!(x, Δ)
    step = data(Δ)
    x.data .+= step
    g = tracker(x).grad
    g .= 0
    return x
end
|
|
|
|
|
|
2018-07-09 15:57:44 +00:00
|
|
|
|
include("idset.jl")
|
2017-09-07 01:21:35 +00:00
|
|
|
|
include("back.jl")
|
2018-02-07 20:39:36 +00:00
|
|
|
|
include("scalar.jl")
|
2018-02-07 17:43:25 +00:00
|
|
|
|
include("array.jl")
|
2017-08-23 00:43:45 +00:00
|
|
|
|
include("numeric.jl")
|
2017-08-19 15:02:19 +00:00
|
|
|
|
|
2018-07-02 12:17:46 +00:00
|
|
|
|
"""
    hook(f, x) -> x′

Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
`f` will be applied to the incoming gradient. For example, `hook(-, x)` will reverse
the sign of the gradient applied to `x`.
"""
hook(f, x) = istracked(x) ? track(hook, f, x) : x

# Forward rule: return the *untracked* value, like every other `@grad` rule.
# `track` builds the new node from the type of this value, so returning the
# tracked `x` itself would nest one tracked value inside another.
# The pullback gives `f` no gradient and passes `f(Δ)` back to `x`.
@grad hook(f, x) = data(x), Δ -> (nothing, f(Δ))
|
2018-07-02 12:17:46 +00:00
|
|
|
|
|
2018-07-09 16:52:34 +00:00
|
|
|
|
"""
    checkpoint(f, args...)

Behave like `f(args...)`, but avoid storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` is called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)

@grad function checkpoint(f, args...)
    # Forward pass: keep only the untracked value of the result.
    out = data(f(args...))
    # Backward pass: recompute `f` through `forward`, then apply its pullback
    # to `Δ`; `f` itself receives no gradient.
    function recompute(Δ)
        _, back = forward(f, args...)
        return (nothing, back(Δ)...)
    end
    return out, recompute
end
|
|
|
|
|
|
2018-07-10 08:03:09 +00:00
|
|
|
|
# Record `x` so its value passes through unchanged, but any attempt to
# backpropagate through it raises an error naming `f`.
nobacksies(f, x) = track(nobacksies, f, x)

# Tuples are handled element-wise, producing a tuple.
nobacksies(f, xs::Tuple) = map(item -> nobacksies(f, item), xs)

@grad function nobacksies(f, x)
    refuse(Δ) = error("Nested AD not defined for $f")
    return data(x), refuse
end
|
|
|
|
|
|
2018-02-08 17:18:40 +00:00
|
|
|
|
# Plain numbers and arrays are converted to floating point and wrapped.
function param(x::Number)
    return TrackedReal(float(x))
end

function param(xs::AbstractArray)
    return TrackedArray(float.(xs))
end

# Differentiable `identity`: the value and the gradient both pass straight through.
@grad function identity(x)
    passthrough(Δ) = (Δ,)
    return data(x), passthrough
end

# Already-tracked values become the output of a recorded `identity` call.
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
|
|
|
|
|
|
2018-03-01 16:31:20 +00:00
|
|
|
|
import NNlib.cudata
|
2018-01-08 16:31:23 +00:00
|
|
|
|
import Adapt.adapt
|
2017-08-24 16:00:48 +00:00
|
|
|
|
|
2018-03-01 16:31:20 +00:00
|
|
|
|
# `cudata` (imported from NNlib): a tracked array exposes its plain data.
cudata(xs::TrackedArray) = data(xs)

# Convert the wrapped data with `adapt(T, ...)` and wrap the result via `param`,
# producing a new tracked array rather than reusing `xs`'s node.
function adapt(T, xs::TrackedArray)
    converted = adapt(T, data(xs))
    return param(converted)
end
|
2017-08-24 16:00:48 +00:00
|
|
|
|
|
2017-08-19 10:06:19 +00:00
|
|
|
|
end
|