2018-02-08 17:18:40 +00:00
|
|
|
"""
    TrackedReal{T<:Real} <: Real

A real number paired with tracking information, used by the `@grad` rules in this
file to differentiate scalar code. Because it subtypes `Real`, it can be passed to
generic numeric functions.

Fields:
- `data::T`: the wrapped, untracked value.
- `tracker::Tracked{T}`: the tracking record for this value (`Tracked` is defined
  elsewhere in the package).
"""
struct TrackedReal{T<:Real} <: Real
    data::T
    tracker::Tracked{T}
end
|
|
|
|
|
2018-07-09 18:44:14 +00:00
|
|
|
"""
    TrackedReal(x::Real)

Wrap the plain value `x` as a new `TrackedReal`. Its tracker holds an empty `Call()`
record and is seeded with `zero(x)`. The role of that second `Tracked` argument is
presumably the initial gradient; confirm against `Tracked`'s definition.
"""
function TrackedReal(x::Real)
    V = typeof(x)
    return TrackedReal(x, Tracked{V}(Call(), zero(x)))
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-07-09 18:44:14 +00:00
|
|
|
"""
    data(x::TrackedReal)

Return the plain value wrapped by `x` (its `data` field).
"""
function data(x::TrackedReal)
    return x.data
end

"""
    tracker(x::TrackedReal)

Return the `Tracked` record stored in `x` (its `tracker` field).
"""
function tracker(x::TrackedReal)
    return x.tracker
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-07-09 18:44:14 +00:00
|
|
|
"""
    track(f::Call, x::Real)

Wrap the result `x` of a recorded call `f` as a `TrackedReal`. Its tracker stores
`f` and is seeded with `zero(x)`.
"""
function track(f::Call, x::Real)
    V = typeof(x)
    rec = Tracked{V}(f, zero(x))
    return TrackedReal(x, rec)
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-05-07 22:30:44 +00:00
|
|
|
"""
    back!(x::TrackedReal)

Run backpropagation starting from the scalar `x` by delegating to `back!(x, 1)`.
The `1` is presumably the seed gradient; `back!(x, Δ)` is defined elsewhere.

Throws an `ErrorException` with message "Loss is Inf" or "Loss is NaN" when `x` is
infinite or NaN. The check runs before any backpropagation takes place.
"""
function back!(x::TrackedReal)
    # Refuse to start from a non-finite loss; Inf is checked before NaN.
    if isinf(x)
        error("Loss is Inf")
    elseif isnan(x)
        error("Loss is NaN")
    end
    return back!(x, 1)
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-02-08 17:18:40 +00:00
|
|
|
# Display a tracked real as its wrapped value followed by " (tracked)",
# e.g. `1.0 (tracked)`.
Base.show(io::IO, x::TrackedReal) = (show(io, data(x)); print(io, " (tracked)"))
|
|
|
|
|
2018-03-06 13:49:05 +00:00
|
|
|
# Forward `decompose` to the wrapped value. Presumably this lets Base's generic
# `Real` code that relies on `decompose` (e.g. hashing/mixed-type comparison)
# handle tracked reals.
function Base.decompose(x::TrackedReal)
    return Base.decompose(data(x))
end
|
|
|
|
|
2018-10-08 19:49:17 +00:00
|
|
|
# `TrackedReal` is an immutable `struct`, so "copying" can simply return the same
# value (sharing its tracker).
function Base.copy(x::TrackedReal)
    return x
end
|
2018-10-08 17:53:32 +00:00
|
|
|
|
2018-02-08 17:18:40 +00:00
|
|
|
# Conversions into `TrackedReal{T}`:
# * from a `TrackedReal{T}` with the same `T`: returned unchanged, tracker kept;
# * from a plain `Real`: the value is converted to `T` and wrapped as a new
#   `TrackedReal` (empty `Call()` record);
# * from a `TrackedReal{S}` with a different `S`: unsupported, raises an error.
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x

function Base.convert(::Type{TrackedReal{T}}, x::Real) where T
    converted = convert(T, x)
    return TrackedReal(converted)
end

function Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S}
    error("Not implemented: convert tracked $S to tracked $T")
end
|
|
|
|
|
2018-08-25 06:51:40 +00:00
|
|
|
# Comparisons on tracked reals act on the wrapped values, so their results are
# not tracked. The (TrackedReal, TrackedReal) method is needed so that calls with
# two tracked arguments are not ambiguous between the two mixed-argument methods.
for op in (:(==), :≈, :<)
    @eval begin
        Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
        Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
        Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
    end
end
|
2018-02-07 20:39:36 +00:00
|
|
|
|
2018-06-27 05:55:43 +00:00
|
|
|
# Machine epsilon of a tracked real is computed from its wrapped value
# `data(x)`, so the result is not tracked.
Base.eps(x::TrackedReal) = eps(data(x))

# Type-level counterpart. `TrackedReal` is not an `AbstractFloat`, so Base provides
# no `eps(::Type{TrackedReal{T}})`. Without this method, generic code calling
# `eps(T)` with `T = TrackedReal{S}` fails with a MethodError. The result is the
# epsilon of the wrapped type, matching the instance method above.
Base.eps(::Type{TrackedReal{T}}) where T = eps(T)
|
|
|
|
|
2018-02-07 23:21:04 +00:00
|
|
|
# Floating-point classification predicates inspect the wrapped value directly.
for pred in (:isinf, :isnan, :isfinite)
    @eval function Base.$pred(x::TrackedReal)
        return Base.$pred(data(x))
    end
end
|
|
|
|
|
2018-02-12 15:05:09 +00:00
|
|
|
# Let `Base.Printf`'s fixed-point formatting helper accept a tracked real by
# formatting its wrapped value. `fix_dec` is an internal of older Julia versions'
# `Base.Printf`; presumably used by `%f`-style formatting — verify on upgrade.
function Base.Printf.fix_dec(x::TrackedReal, n::Int)
    return Base.Printf.fix_dec(data(x), n)
end
|
2018-02-07 23:21:04 +00:00
|
|
|
|
2018-02-08 17:18:40 +00:00
|
|
|
# Promoting a tracked real with another type `T` gives a tracked real whose wrapped
# type is `promote_type(S, T)`, e.g. `TrackedReal{Float32}` with `Float64` gives
# `TrackedReal{Float64}`.
Base.promote_rule(::Type{TrackedReal{S}}, ::Type{T}) where {S,T} =
    TrackedReal{promote_type(S, T)}

# When both sides are tracked, promote the *wrapped* types. Without this more
# specific method, the rule above binds `T = TrackedReal{T′}` and yields a nested
# `TrackedReal{TrackedReal{…}}` (e.g. for Float32 vs Float64 tracked reals). That
# makes later conversions fail with a confusing type in the error message.
Base.promote_rule(::Type{TrackedReal{S}}, ::Type{TrackedReal{T}}) where {S,T} =
    TrackedReal{promote_type(S, T)}
|
2018-02-07 20:39:36 +00:00
|
|
|
|
|
|
|
using DiffRules, SpecialFunctions, NaNMath
|
|
|
|
|
|
|
|
# Lift every one-argument DiffRules rule `M.f` to tracked reals:
# * the `@grad` rule returns the forward value `M.f(data(a))` and a pullback that
#   multiplies the incoming `Δ` by DiffRules' symbolic derivative `da` (an
#   expression in the symbol `a`);
# * `M.f(::TrackedReal)` is routed through `track`, which presumably records the
#   call and uses the `@grad` rule (`track`/`@grad` are defined elsewhere).
for (M, f, arity) in DiffRules.diffrules()
    arity != 1 && continue
    fn = :($M.$f)                         # qualified name, e.g. `Base.sin`
    da = DiffRules.diffrule(M, f, :a)     # derivative expression w.r.t. `a`
    @eval begin
        @grad $fn(a::Real) =
            $fn(data(a)), Δ -> (Δ * $da,)
        $fn(a::TrackedReal) = track($fn, a)
    end
end
|
|
|
|
|
|
|
|
# Lift every two-argument DiffRules rule `M.f` to tracked reals, for each mix of
# tracked and untracked arguments:
# * `@grad` rules return the forward value and a pullback. `da`/`db` are DiffRules'
#   symbolic partial derivatives w.r.t. `a`/`b`; an untracked argument receives a
#   `zero` gradient.
# * methods of `M.f` itself route the tracked call through `track`.
for (M, f, arity) in DiffRules.diffrules()
    arity != 2 && continue
    fn = :($M.$f)                                 # qualified name, e.g. `Base.atan`
    da, db = DiffRules.diffrule(M, f, :a, :b)     # partials w.r.t. `a` and `b`
    @eval begin
        @grad $fn(a::TrackedReal, b::TrackedReal) =
            $fn(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
        @grad $fn(a::TrackedReal, b::Real) =
            $fn(data(a), b), Δ -> (Δ * $da, zero(b))
        @grad $fn(a::Real, b::TrackedReal) =
            $fn(a, data(b)), Δ -> (zero(a), Δ * $db)
        $fn(a::TrackedReal, b::TrackedReal) = track($fn, a, b)
        $fn(a::TrackedReal, b::Real) = track($fn, a, b)
        $fn(a::Real, b::TrackedReal) = track($fn, a, b)
    end
end
|
2018-02-07 22:20:44 +00:00
|
|
|
|
2018-03-13 02:50:56 +00:00
|
|
|
# Resolve a method ambiguity for `^` with a tracked base and an `Integer` exponent
|
2018-03-12 14:48:16 +00:00
|
|
|
import Base:^
|
|
|
|
|
|
|
|
# A tracked base raised to an integer power. The generated binary rules already
# cover `^(::TrackedReal, ::Real)`. Presumably Base's integer-power methods
# (`(::Number, ::Integer)`-style signatures) make that call ambiguous, so this more
# specific method settles dispatch and takes the same `track(^, a, b)` path.
function (^)(a::TrackedReal, b::Integer)
    return track(^, a, b)
end
|
|
|
|
|
2018-02-07 22:20:44 +00:00
|
|
|
# Tuples
|
|
|
|
|
|
|
|
"""
    TrackedTuple{T<:Tuple}

A tuple value paired with tracking information; the tuple counterpart of
`TrackedReal`. It declares no supertype, so it is not itself a `Tuple`.

Fields:
- `data::T`: the wrapped, untracked tuple.
- `tracker::Tracked{T}`: the tracking record (`Tracked` is defined elsewhere).
"""
struct TrackedTuple{T<:Tuple}
    data::T
    tracker::Tracked{T}
end
|
|
|
|
|
2018-07-10 17:16:37 +00:00
|
|
|
"""
    data(xs::TrackedTuple)

Return the plain tuple wrapped by `xs` (its `data` field).
"""
function data(xs::TrackedTuple)
    return xs.data
end

"""
    tracker(xs::TrackedTuple)

Return the `Tracked` record stored in `xs` (its `tracker` field).
"""
function tracker(xs::TrackedTuple)
    return xs.tracker
end
|
|
|
|
|
|
|
|
# Tuple-valued gradients are handled element by element: each helper broadcasts
# the corresponding per-element method over the tuple(s) and returns the
# resulting tuple.

function accum!(x::Tuple, Δ::Tuple)
    return accum!.(x, Δ)
end

function init_grad(x::Tuple)
    return init_grad.(x)
end

function zero_grad!(x::Tuple)
    return zero_grad!.(x)
end
|
2018-02-07 22:20:44 +00:00
|
|
|
|
2018-07-10 17:16:37 +00:00
|
|
|
"""
    track(f::Call, xs::Tuple)

Wrap the tuple result `xs` of a recorded call `f` as a `TrackedTuple`. Its tracker
stores `f` and is seeded with the elementwise zeros `zero.(xs)`.
"""
function track(f::Call, xs::Tuple)
    zs = zero.(xs)
    return TrackedTuple(xs, Tracked{typeof(xs)}(f, zs))
end
|
2018-02-07 22:20:44 +00:00
|
|
|
|
|
|
|
# Display a tracked tuple as its wrapped tuple followed by " (tracked)",
# e.g. `(1.0, 2.0) (tracked)`.
function Base.show(io::IO, xs::TrackedTuple)
    inner = data(xs)
    show(io, inner)
    print(io, " (tracked)")
end
|
|
|
|
|
|
|
|
# Number of elements in the wrapped tuple.
function Base.length(x::TrackedTuple)
    return length(data(x))
end
|
|
|
|
|
|
|
|
# Integer indexing of a tracked tuple is itself a tracked operation. It goes
# through `track`, which presumably uses the `@grad getindex` rule for
# `TrackedTuple` (`track` is defined elsewhere).
function Base.getindex(xs::TrackedTuple, i::Integer)
    return track(getindex, xs, i)
end
|
|
|
|
|
2018-07-10 17:16:37 +00:00
|
|
|
# Gradient rule for indexing a tracked tuple:
# * forward: the `i`-th element of the wrapped tuple;
# * pullback: a tuple of length `length(xs)` with `Δ` at position `i` and `0`
#   everywhere else, plus `nothing` for the (non-differentiable) index.
@grad function getindex(xs::TrackedTuple, i)
    value = data(xs)[i]
    pullback = Δ -> (ntuple(j -> j == i ? Δ : 0, length(xs)), nothing)
    return value, pullback
end
|
2018-06-06 16:01:28 +00:00
|
|
|
|
|
|
|
# Array collection
|
|
|
|
|
|
|
|
"""
    collect(xs)

Module-local `collect` (it does not extend `Base.collect`). It materialises `xs`
with `Base.collect`, then wraps the untracked element values `data.(…)` via `track`.
The recorded `Call` holds the per-element trackers `tracker.(…)`, so gradients can
be routed back to each element.

Every element must support `data` and `tracker` (e.g. `TrackedReal`). The result
is presumably a tracked array; `track` for arrays is defined elsewhere.
"""
function collect(xs)
    arr = Base.collect(xs)
    trackers = tracker.(arr)
    return track(Call(collect, (trackers,)), data.(arr))
end
|
|
|
|
|
|
|
|
# For a call recorded by `collect`, apply `scan` to each element tracker stored
# in its first argument. `scan` itself is defined elsewhere (presumably a graph
# walk before backpropagation).
function scan(c::Call{typeof(collect)})
    for t in c.args[1]
        scan(t)
    end
    return nothing
end
|
|
|
|
|
2018-07-09 12:39:10 +00:00
|
|
|
# Backpropagate through `collect`: pair each element tracker with the matching
# entry of the untracked incoming gradient `data(Δ)` and call `back` on each pair.
# Pairing stops at the shorter of the two.
function back_(c::Call{typeof(collect)}, Δ)
    for (t, d) in zip(c.args[1], data(Δ))
        back(t, d)
    end
    return nothing
end
|
2018-08-07 21:09:20 +00:00
|
|
|
|
|
|
|
# `Grads`-collecting variant of backpropagation through `collect`: pair each
# element tracker with the matching entry of `Δ` (used as-is, not unwrapped) and
# call `back(g, tracker, gradient)` on each pair. Pairing stops at the shorter of
# the two.
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
    for (t, d) in zip(c.args[1], Δ)
        back(g, t, d)
    end
    return nothing
end
|