TrackedNumber -> TrackedReal
This commit is contained in:
parent
d1c56ca768
commit
fc157a8c59
@ -44,7 +44,7 @@ include("scalar.jl")
|
|||||||
include("array.jl")
|
include("array.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
param(x::Number) = TrackedNumber(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
using DataFlow
|
using DataFlow
|
||||||
|
@ -237,7 +237,7 @@ end
|
|||||||
|
|
||||||
dualify(xs, n) = xs
|
dualify(xs, n) = xs
|
||||||
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
|
||||||
dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps)
|
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps)
|
||||||
|
|
||||||
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||||
|
@ -1,51 +1,46 @@
|
|||||||
struct TrackedNumber{T<:Number} <: Number
|
struct TrackedReal{T<:Real} <: Real
|
||||||
tracker::Tracked{T}
|
tracker::Tracked{T}
|
||||||
end
|
end
|
||||||
|
|
||||||
TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x, zero(x)))
|
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x)))
|
||||||
|
|
||||||
tracker(x::TrackedNumber) = x.tracker
|
tracker(x::TrackedReal) = x.tracker
|
||||||
|
|
||||||
track(f::Call, x::Number) = TrackedNumber(Tracked(f, x, zero(x)))
|
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x)))
|
||||||
|
|
||||||
back!(x::TrackedNumber) = back!(x, 1)
|
back!(x::TrackedReal) = back!(x, 1)
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedNumber)
|
function Base.show(io::IO, x::TrackedReal)
|
||||||
show(io, data(x))
|
show(io, data(x))
|
||||||
print(io, " (tracked)")
|
print(io, " (tracked)")
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x
|
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{T}) where T = x
|
||||||
|
|
||||||
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T =
|
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal) where T =
|
||||||
TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
TrackedReal(Tracked(x.tracker.f, convert(T, x.tracker.data)))
|
||||||
|
|
||||||
Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x))
|
Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x))
|
||||||
|
|
||||||
Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y)
|
Base.:(<)(x::TrackedReal, y::TrackedReal) = data(x) < data(y)
|
||||||
Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y))
|
Base.:(==)(x::TrackedReal, y::TrackedReal) = data(x) == data(y)
|
||||||
Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y))
|
|
||||||
|
|
||||||
Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y
|
|
||||||
Base.:(==)(x::Number, y::TrackedNumber) = x == data(y)
|
|
||||||
Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y)
|
|
||||||
|
|
||||||
for f in :[isinf, isnan, isfinite].args
|
for f in :[isinf, isnan, isfinite].args
|
||||||
@eval Base.$f(x::TrackedNumber) = Base.$f(data(x))
|
@eval Base.$f(x::TrackedReal) = Base.$f(data(x))
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.Printf.fix_dec(x::TrackedNumber, n::Int) = Base.Printf.fix_dec(data(x), n)
|
# Base.Printf.fix_dec(x::TrackedReal, n::Int) = Base.Printf.fix_dec(data(x), n)
|
||||||
|
|
||||||
Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} =
|
Base.promote_rule(::Type{TrackedReal{S}},::Type{T}) where {S,T} =
|
||||||
TrackedNumber{promote_type(S,T)}
|
TrackedReal{promote_type(S,T)}
|
||||||
|
|
||||||
using DiffRules, SpecialFunctions, NaNMath
|
using DiffRules, SpecialFunctions, NaNMath
|
||||||
|
|
||||||
for (M, f, arity) in DiffRules.diffrules()
|
for (M, f, arity) in DiffRules.diffrules()
|
||||||
arity == 1 || continue
|
arity == 1 || continue
|
||||||
@eval begin
|
@eval begin
|
||||||
$M.$f(a::TrackedNumber) = track($M.$f, a)
|
$M.$f(a::TrackedReal) = track($M.$f, a)
|
||||||
back(::typeof($M.$f), Δ::Number, a::TrackedNumber) =
|
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
|
||||||
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
|
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -54,10 +49,10 @@ for (M, f, arity) in DiffRules.diffrules()
|
|||||||
arity == 2 || continue
|
arity == 2 || continue
|
||||||
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
||||||
@eval begin
|
@eval begin
|
||||||
$M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b)
|
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b)
|
||||||
$M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b)
|
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b)
|
||||||
$M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b)
|
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b)
|
||||||
function back(::typeof($M.$f), Δ::Number, a::Number, b::Number)
|
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real)
|
||||||
@back(a, Δ * $da)
|
@back(a, Δ * $da)
|
||||||
@back(b, Δ * $db)
|
@back(b, Δ * $db)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user