TrackedNumber -> TrackedReal

This commit is contained in:
Mike J Innes 2018-02-08 17:18:40 +00:00
parent d1c56ca768
commit fc157a8c59
3 changed files with 24 additions and 29 deletions

View File

@ -44,7 +44,7 @@ include("scalar.jl")
include("array.jl") include("array.jl")
include("numeric.jl") include("numeric.jl")
param(x::Number) = TrackedNumber(float(x)) param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs)) param(xs::AbstractArray) = TrackedArray(float.(xs))
using DataFlow using DataFlow

View File

@ -237,7 +237,7 @@ end
dualify(xs, n) = xs dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps) dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
function tracked_broadcast(f, args::Vararg{Any,N}) where N function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))

View File

@ -1,51 +1,46 @@
struct TrackedNumber{T<:Number} <: Number struct TrackedReal{T<:Real} <: Real
tracker::Tracked{T} tracker::Tracked{T}
end end
TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x, zero(x))) TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
tracker(x::TrackedNumber) = x.tracker tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Number) = TrackedNumber(Tracked(f, x, zero(x))) track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
back!(x::TrackedNumber) = back!(x, 1) back!(x::TrackedReal) = back!(x, 1)
function Base.show(io::IO, x::TrackedNumber) function Base.show(io::IO, x::TrackedReal)
show(io, data(x)) show(io, data(x))
print(io, " (tracked)") print(io, " (tracked)")
end end
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T = Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data))) TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x)) Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y) Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y)) Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y))
Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y
Base.:(==)(x::Number, y::TrackedNumber) = x == data(y)
Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y)
for f in :[isinf, isnan, isfinite].args for f in :[isinf, isnan, isfinite].args
@eval Base.$f(x::TrackedNumber) = Base.$f(data(x)) @eval Base.$f(x::TrackedReal) = Base.$f(data(x))
end end
Base.Printf.fix_dec(x::TrackedNumber, n::Int) = Base.Printf.fix_dec(data(x), n) # Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} = Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
TrackedNumber{promote_type(S,T)} TrackedReal{promote_type(S,T)}
using DiffRules, SpecialFunctions, NaNMath using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue arity == 1 || continue
@eval begin @eval begin
$M.$f(a::TrackedNumber) = track($M.$f, a) $M.$f(a::TrackedReal) = track($M.$f, a)
back(::typeof($M.$f), Δ::Number, a::TrackedNumber) = back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a))))) back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end end
end end
@ -54,10 +49,10 @@ for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
@eval begin @eval begin
$M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b) $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
$M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b) $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
$M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b) $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
function back(::typeof($M.$f), Δ::Number, a::Number, b::Number) function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
@back(a, Δ * $da) @back(a, Δ * $da)
@back(b, Δ * $db) @back(b, Δ * $db)
end end