# Flux.jl/src/tracker/scalar.jl (125 lines, 3.2 KiB, Julia)

2018-02-08 17:18:40 +00:00
"""
    TrackedReal{T<:Real} <: Real

Scalar wrapper that stores a plain real value in `data` next to its
`Tracked{T}` record in `tracker`. It is declared as a subtype of `Real`.
"""
struct TrackedReal{T<:Real} <: Real
    data::T
    tracker::Tracked{T}
end
2018-07-09 18:44:14 +00:00
"""
    TrackedReal(x::Real)

Wrap `x` in a `TrackedReal` whose tracker is constructed from an empty
`Call()` and `zero(x)`.
"""
function TrackedReal(x::Real)
    record = Tracked{typeof(x)}(Call(), zero(x))
    return TrackedReal(x, record)
end
2018-02-07 20:39:36 +00:00
2018-07-09 18:44:14 +00:00
# Field accessors for `TrackedReal`.

# Return the wrapped plain value.
function data(x::TrackedReal)
    return x.data
end

# Return the `Tracked` record stored alongside the value.
function tracker(x::TrackedReal)
    return x.tracker
end
2018-02-07 20:39:36 +00:00
2018-07-09 18:44:14 +00:00
"""
    track(f::Call, x::Real)

Build a `TrackedReal` around the value `x`, attaching a `Tracked` record
constructed from the call `f` and `zero(x)`.
"""
function track(f::Call, x::Real)
    # NOTE(review): `zero(x)` is presumably the initial gradient slot; confirm
    # against the `Tracked` definition.
    record = Tracked{typeof(x)}(f, zero(x))
    return TrackedReal(x, record)
end
2018-02-07 20:39:36 +00:00
"""
    back!(x::TrackedReal)

Refuse to proceed when `x` is infinite or NaN (raising an error in either
case); otherwise delegate to the two-argument `back!(x, 1)`.
"""
function back!(x::TrackedReal)
    if isinf(x)
        error("Loss is Inf")
    elseif isnan(x)
        error("Loss is NaN")
    end
    return back!(x, 1)
end
2018-02-07 20:39:36 +00:00
2018-02-08 17:18:40 +00:00
# Display a tracked scalar as its underlying value followed by " (tracked)".
function Base.show(io::IO, x::TrackedReal)
    value = data(x)
    show(io, value)
    return print(io, " (tracked)")
end
2018-03-06 13:49:05 +00:00
# Decompose a tracked scalar by decomposing the plain value it wraps.
function Base.decompose(x::TrackedReal)
    return Base.decompose(data(x))
end
2018-02-08 17:18:40 +00:00
# Conversions targeting `TrackedReal{T}`.

# Already the requested tracked type: hand the value back untouched.
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where {T} = x

# Any other real: convert it to `T` first, then wrap it in a new TrackedReal.
function Base.convert(::Type{TrackedReal{T}}, x::Real) where {T}
    return TrackedReal(convert(T, x))
end

# A tracked value with a different element type is not supported; raise an error.
function Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S}
    return error("Not implemented: convert tracked $S to tracked $T")
end
# Comparison operators (==, ≈, <) involving tracked scalars are evaluated on
# the underlying plain values, for every tracked/untracked argument pairing.
for relop in (:(==), :≈, :<)
    @eval begin
        Base.$relop(x::TrackedReal, y::Real) = Base.$relop(data(x), y)
        Base.$relop(x::Real, y::TrackedReal) = Base.$relop(x, data(y))
        Base.$relop(x::TrackedReal, y::TrackedReal) = Base.$relop(data(x), data(y))
    end
end
2018-02-07 20:39:36 +00:00
2018-06-27 05:55:43 +00:00
# `eps` of a tracked scalar is the `eps` of the plain value it wraps.
function Base.eps(x::TrackedReal)
    return eps(data(x))
end
2018-02-07 23:21:04 +00:00
# Forward the floating-point predicates to the wrapped value.
for pred in (:isinf, :isnan, :isfinite)
    @eval Base.$pred(x::TrackedReal) = Base.$pred(data(x))
end
2018-02-12 15:05:09 +00:00
# Printf support: format a tracked scalar by delegating to the wrapped value.
function Base.Printf.fix_dec(x::TrackedReal, n::Int)
    return Base.Printf.fix_dec(data(x), n)
end
2018-02-07 23:21:04 +00:00
2018-02-08 17:18:40 +00:00
# Promoting `TrackedReal{S}` with a type `T` yields a `TrackedReal` whose
# parameter is `promote_type(S, T)`.
function Base.promote_rule(::Type{TrackedReal{S}}, ::Type{T}) where {S,T}
    return TrackedReal{promote_type(S, T)}
end
2018-02-07 20:39:36 +00:00
using DiffRules, SpecialFunctions, NaNMath
# One-argument rules from DiffRules: for every `M.f` of arity 1, register a
# `@grad` definition pairing the primal value with a closure that scales the
# incoming Δ by the rule's derivative expression, and add a `TrackedReal`
# method that routes the call through `track`.
for (M, f, arity) in DiffRules.diffrules()
    if arity != 1
        continue
    end
    # Derivative expression written in terms of the symbol `a`.
    deriv = DiffRules.diffrule(M, f, :a)
    @eval begin
        @grad $M.$f(a::Real) = $M.$f(data(a)), Δ -> (Δ * $deriv,)
        $M.$f(a::TrackedReal) = track($M.$f, a)
    end
end
# Two-argument rules from DiffRules: for every `M.f` of arity 2, register a
# `@grad` definition returning the primal value plus a closure that produces
# one scaled Δ per argument, and add methods for each tracked/untracked
# argument pairing that route the call through `track`.
for (M, f, arity) in DiffRules.diffrules()
    if arity != 2
        continue
    end
    # Partial-derivative expressions written in terms of the symbols `a`, `b`.
    partial_a, partial_b = DiffRules.diffrule(M, f, :a, :b)
    # Module-qualified function reference spliced into the definitions below.
    qualified = :($M.$f)
    @eval begin
        @grad $qualified(a::Real, b::Real) =
            $qualified(data(a), data(b)), Δ -> (Δ * $partial_a, Δ * $partial_b)
        $qualified(a::TrackedReal, b::TrackedReal) = track($qualified, a, b)
        $qualified(a::TrackedReal, b::Real) = track($qualified, a, b)
        $qualified(a::Real, b::TrackedReal) = track($qualified, a, b)
    end
end
2018-02-07 22:20:44 +00:00
2018-03-13 02:50:56 +00:00
# Eliminating ambiguity
2018-03-12 14:48:16 +00:00
import Base:^
# Dedicated method for integer exponents; the file notes that it exists to
# eliminate a dispatch ambiguity.
function Base.:^(a::TrackedReal, b::Integer)
    return track(^, a, b)
end
2018-02-07 22:20:44 +00:00
# Tuples
"""
    TrackedTuple{T<:Tuple}

Tuple wrapper that stores the plain tuple in `data` next to its `Tracked{T}`
record in `tracker`.
"""
struct TrackedTuple{T<:Tuple}
    data::T
    tracker::Tracked{T}
end
2018-07-10 17:16:37 +00:00
# Field accessors for `TrackedTuple`.

# Return the wrapped plain tuple.
function data(xs::TrackedTuple)
    return xs.data
end

# Return the `Tracked` record stored alongside the tuple.
function tracker(xs::TrackedTuple)
    return xs.tracker
end
# Tuple versions of the gradient helpers: each one applies the same helper to
# every element and returns the results as a tuple.

# Pair up the elements of `x` and `Δ` and accumulate each pair.
function accum!(x::Tuple, Δ::Tuple)
    return broadcast(accum!, x, Δ)
end

# Initialise a gradient for every element of `x`.
function init_grad(x::Tuple)
    return map(init_grad, x)
end

# Reset the gradient of every element of `x`.
function zero_grad!(x::Tuple)
    return map(zero_grad!, x)
end
2018-02-07 22:20:44 +00:00
2018-07-10 17:16:37 +00:00
"""
    track(f::Call, xs::Tuple)

Build a `TrackedTuple` around `xs`, attaching a `Tracked` record constructed
from the call `f` and the element-wise zeros of `xs`.
"""
function track(f::Call, xs::Tuple)
    zeros_like = map(zero, xs)
    record = Tracked{typeof(xs)}(f, zeros_like)
    return TrackedTuple(xs, record)
end
2018-02-07 22:20:44 +00:00
# Display a tracked tuple as its underlying tuple followed by " (tracked)".
function Base.show(io::IO, xs::TrackedTuple)
    contents = data(xs)
    show(io, contents)
    return print(io, " (tracked)")
end
# A tracked tuple has as many elements as the tuple it wraps.
function Base.length(x::TrackedTuple)
    return length(data(x))
end
# Indexing a tracked tuple with an integer is routed through `track`.
function Base.getindex(xs::TrackedTuple, i::Integer)
    return track(getindex, xs, i)
end
2018-07-10 17:16:37 +00:00
# Gradient rule for indexing a tracked tuple, registered through `@grad`.
# Yields the element `data(xs)[i]` together with a closure over the incoming
# gradient Δ. The closure returns a tuple with one slot per element of `xs`
# (Δ at position `i`, the literal `0` everywhere else), followed by `nothing`
# in the position that corresponds to the index argument `i`.
# NOTE(review): how `@grad` consumes this pair is defined outside this section.
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
2018-06-06 16:01:28 +00:00
# Array collection
"""
    collect(xs)

Materialise `xs` with `Base.collect`, then return a tracked result whose data
is the element-wise `data` of the collected items and whose recorded call is
`collect` applied to their element-wise trackers.
"""
function collect(xs)
    items = Base.collect(xs)
    recorded = Call(collect, (tracker.(items),))
    return track(recorded, data.(items))
end
# Scanning a recorded `collect` call scans each entry of its first argument.
scan(c::Call{typeof(collect)}) = foreach(scan, c.args[1])
2018-07-09 12:39:10 +00:00
# Backward step for a recorded `collect` call: pair each entry of the call's
# first argument with the matching entry of `data(Δ)` and invoke `back` on
# every pair.
back_(c::Call{typeof(collect)}, Δ) = foreach(back, c.args[1], data(Δ))
2018-08-07 21:09:20 +00:00
# Backward step for a recorded `collect` call when a `Grads` object is given:
# pair each entry of the call's first argument with the matching entry of Δ
# and invoke `back(g, entry, gradient)` on every pair.
function back_(g::Grads, c::Call{typeof(collect)}, Δ)
    entries = c.args[1]
    return foreach(entries, Δ) do entry, δ
        back(g, entry, δ)
    end
end