Flux.jl/src/tracker/scalar.jl

113 lines
2.8 KiB
Julia
Raw Normal View History

2018-02-08 17:18:40 +00:00
"""
    TrackedReal{T<:Real} <: Real

A real number that carries a `Tracked{T}` bookkeeping record in its single
`tracker` field (read it back with `tracker(x)`).
"""
struct TrackedReal{T<:Real} <: Real
    tracker::Tracked{T}
end
2018-02-08 17:18:40 +00:00
# Wrap a plain real as a tracked leaf: the record holds an empty `Call(nothing)`,
# the value itself, and `zero(x)` as its third field.
function TrackedReal(x::Real)
    record = Tracked(Call(nothing), x, zero(x))
    return TrackedReal(record)
end
2018-02-07 20:39:36 +00:00
2018-02-08 17:18:40 +00:00
# Accessor for the bookkeeping record stored inside a tracked scalar.
function tracker(x::TrackedReal)
    return x.tracker
end
2018-02-07 20:39:36 +00:00
2018-02-08 17:18:40 +00:00
# Wrap `x`, the result of the recorded call `f`, as a tracked scalar whose third
# `Tracked` field starts at `zero(x)`.
function track(f::Call, x::Real)
    record = Tracked(f, x, zero(x))
    return TrackedReal(record)
end
2018-02-07 20:39:36 +00:00
"""
    back!(x::TrackedReal)

Start backpropagation from the scalar `x`, seeding it with a gradient of `1`.
Errors out first if `x` is infinite, then if it is NaN.
"""
function back!(x::TrackedReal)
    if isinf(x)
        error("Loss is Inf")
    elseif isnan(x)
        error("Loss is NaN")
    end
    return back!(x, 1)
end
2018-02-07 20:39:36 +00:00
2018-02-08 17:18:40 +00:00
# Display the underlying value followed by a " (tracked)" marker.
Base.show(io::IO, x::TrackedReal) = (show(io, data(x)); print(io, " (tracked)"))
2018-03-06 13:49:05 +00:00
# Forward `decompose` to the plain value. Base's generic `hash` for `Real` numbers
# is built on `decompose`, so this keeps a tracked scalar hashing like its data.
function Base.decompose(x::TrackedReal)
    return Base.decompose(data(x))
end
2018-02-08 17:18:40 +00:00
# Converting a tracked value to the type it already has is the identity.
function Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T
    return x
end
2018-02-07 20:39:36 +00:00
2018-02-08 17:18:40 +00:00
# Lift a plain real into a fresh tracked leaf after converting it to payload type `T`.
function Base.convert(::Type{TrackedReal{T}}, x::Real) where T
    payload = convert(T, x)
    return TrackedReal(payload)
end
2018-02-07 20:39:36 +00:00
2018-06-25 10:36:52 +00:00
# Changing the payload type of an already-tracked value is unsupported and errors.
function Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S}
    error("Not implemented: convert tracked $S to tracked $T")
end
2018-02-08 17:18:40 +00:00
# Ordering and equality look through the trackers and compare plain values only.
function Base.:(<)(x::TrackedReal, y::TrackedReal)
    return data(x) < data(y)
end

function Base.:(==)(x::TrackedReal, y::TrackedReal)
    return data(x) == data(y)
end
2018-02-07 20:39:36 +00:00
2018-02-07 23:21:04 +00:00
# Floating-point classification predicates answer for the underlying value.
for pred in (:isinf, :isnan, :isfinite)
    @eval Base.$pred(x::TrackedReal) = Base.$pred(data(x))
end
2018-02-12 15:05:09 +00:00
# Let Printf's fixed-decimal formatting work on tracked scalars via their plain value.
function Base.Printf.fix_dec(x::TrackedReal, n::Int)
    return Base.Printf.fix_dec(data(x), n)
end
2018-02-07 23:21:04 +00:00
2018-02-08 17:18:40 +00:00
# Mixing a tracked real with any other type promotes to a tracked real whose payload
# type is the promotion of the two underlying types.
Base.promote_rule(::Type{TrackedReal{S}}, ::Type{T}) where {S,T} =
    TrackedReal{promote_type(S,T)}

# When both sides are tracked, promote the payloads directly. Without this more
# specific rule, the method above would compute promote_type(S, TrackedReal{T}) and
# yield a nested TrackedReal{TrackedReal{...}}. Note that converting between payload
# types still raises the "Not implemented" convert error, but it now names the real
# payload types.
Base.promote_rule(::Type{TrackedReal{S}}, ::Type{TrackedReal{T}}) where {S,T} =
    TrackedReal{promote_type(S,T)}
2018-02-07 20:39:36 +00:00
using DiffRules, SpecialFunctions, NaNMath
# Unary primitives. For each one-argument rule registered in DiffRules:
# - add a method on TrackedReal that records the call;
# - add its adjoint `back`, which forwards Δ times the rule's derivative (evaluated
#   on the untracked value `data(a)`) to `a`.
for (modname, fname, nargs) in DiffRules.diffrules()
    nargs == 1 || continue
    dfda = DiffRules.diffrule(modname, fname, :(data(a)))
    @eval begin
        $modname.$fname(a::TrackedReal) = track($modname.$fname, a)
        back(::typeof($modname.$fname), Δ::Real, a::TrackedReal) =
            back(a, Δ * $dfda)
    end
end
# Binary primitives. Every two-argument DiffRules rule gets:
# - tracked methods for each argument combination that involves a TrackedReal;
# - an adjoint that hands Δ times the matching partial derivative to each argument
#   through `@back` (presumably a no-op for untracked arguments — see its definition).
for (modname, fname, nargs) in DiffRules.diffrules()
    nargs == 2 || continue
    da_expr, db_expr = DiffRules.diffrule(modname, fname, :(data(a)), :(data(b)))
    @eval begin
        function back(::typeof($modname.$fname), Δ::Real, a::Real, b::Real)
            @back(a, Δ * $da_expr)
            @back(b, Δ * $db_expr)
        end
        $modname.$fname(a::TrackedReal, b::TrackedReal) = track($modname.$fname, a, b)
        $modname.$fname(a::TrackedReal, b::Real) = track($modname.$fname, a, b)
        $modname.$fname(a::Real, b::TrackedReal) = track($modname.$fname, a, b)
    end
end
2018-02-07 22:20:44 +00:00
2018-03-13 02:50:56 +00:00
# Eliminating ambiguity
2018-03-12 14:48:16 +00:00
import Base:^
# Integer powers of a tracked real. Without this method, the generic tracked `^` from
# the binary loop would (presumably) be ambiguous with Base's integer-power methods.
function ^(a::TrackedReal, b::Integer)
    return track(^, a, b)
end
2018-02-07 22:20:44 +00:00
# Tuples
"""
    TrackedTuple{T<:Tuple}

Tuple-valued counterpart of `TrackedReal`: holds a `Tracked{T}` record in its
`tracker` field.
"""
struct TrackedTuple{T<:Tuple}
    tracker::Tracked{T}
end
# Accessor for a tracked tuple's bookkeeping record.
function tracker(xs::TrackedTuple)
    return xs.tracker
end
# Tuple gradients accumulate elementwise.
accum!(x::Tuple, Δ::Tuple) = broadcast(accum!, x, Δ)
# Initialise a tuple's gradient by initialising each element's.
init_grad(x::Tuple) = broadcast(init_grad, x)
2018-02-12 12:31:15 +00:00
# Reset a tuple's gradient by resetting each element's.
zero_grad!(x::Tuple) = broadcast(zero_grad!, x)
2018-02-07 22:20:44 +00:00
# Wrap the tuple produced by call `f` as a tracked tuple. Unlike the scalar case,
# `Tracked` is built from the call and value alone.
function track(f::Call, xs::Tuple)
    record = Tracked(f, xs)
    return TrackedTuple(record)
end
# Display the wrapped tuple followed by a " (tracked)" marker.
Base.show(io::IO, xs::TrackedTuple) = (show(io, data(xs)); print(io, " (tracked)"))
# A tracked tuple has as many elements as the tuple it wraps.
function Base.length(x::TrackedTuple)
    return length(data(x))
end
# Element access is itself recorded, so gradients can flow back into the tuple.
function Base.getindex(xs::TrackedTuple, i::Integer)
    return track(getindex, xs, i)
end
# Adjoint of tuple indexing. Pass back a tuple the length of `t` that holds Δ in
# slot `i` and the literal 0 everywhere else.
function back(::typeof(getindex), Δ, t, i)
    onehot = ntuple(length(t)) do j
        j == i ? Δ : 0
    end
    return back(t, onehot)
end
2018-06-06 16:01:28 +00:00
# Array collection
# Materialise `xs` into an array, then track an array of the elements' plain values.
# The recorded `Call` keeps the original (possibly tracked) elements for backprop.
function collect(xs)
    elems = Base.collect(xs)
    return track(Call(collect, elems), data.(elems))
end
# Scanning a recorded `collect` call visits each collected element.
function scan(c::Call{typeof(collect)})
    for elem in c.args[1]
        scan(elem)
    end
end
# Adjoint of `collect`: pair each original element with its gradient component and
# backpropagate through it (iteration stops at the shorter of the two).
function back(::typeof(collect), Δ, xs)
    for (elem, δ) in zip(xs, Δ)
        @back(elem, δ)
    end
end