Compare commits
1 Commits
master
...
mji/comple
Author | SHA1 | Date |
---|---|---|
![]() |
3eb5e07ded |
|
@ -70,6 +70,7 @@ include("idset.jl")
|
|||
include("back.jl")
|
||||
include("numeric.jl")
|
||||
include("lib/real.jl")
|
||||
include("lib/complex.jl")
|
||||
include("lib/array.jl")
|
||||
|
||||
"""
|
||||
|
|
|
@ -14,10 +14,7 @@ function scan(x::Tracked)
|
|||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
end
|
||||
scan(::Nothing) = return
|
||||
|
||||
function back_(c::Call, Δ, once)
|
||||
Δs = c.func(Δ)
|
||||
|
@ -61,7 +58,7 @@ back(::Nothing, Δ, once) = return
|
|||
|
||||
function back!(x, Δ; once = true)
|
||||
istracked(x) || return
|
||||
scan(x)
|
||||
scan(tracker(x))
|
||||
back(tracker(x), Δ, once)
|
||||
return
|
||||
end
|
||||
|
@ -143,16 +140,19 @@ function forward(f, ps::Params)
|
|||
y, function (Δ)
|
||||
g = Grads(ps)
|
||||
if istracked(y)
|
||||
scan(y)
|
||||
scan(tracker(y))
|
||||
back(g, tracker(y), Δ)
|
||||
end
|
||||
return g
|
||||
end
|
||||
end
|
||||
|
||||
# Essentially a hack for complex numbers
|
||||
unwrap(x) = x
|
||||
|
||||
function forward(f, args...)
|
||||
args = param.(args)
|
||||
y, back = forward(() -> f(args...), Params(args))
|
||||
y, back = forward(() -> f(unwrap.(args)...), Params(args))
|
||||
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# Internal interface
|
||||
|
||||
struct _TrackedComplex{T<:Real}
|
||||
data::Complex{T}
|
||||
tracker::Tracked{Complex{T}}
|
||||
end
|
||||
|
||||
_TrackedComplex(x::Complex) = _TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||
|
||||
data(x::_TrackedComplex) = x.data
|
||||
tracker(x::_TrackedComplex) = x.tracker
|
||||
|
||||
Base.real(x::_TrackedComplex) = track(real, x)
|
||||
Base.imag(x::_TrackedComplex) = track(imag, x)
|
||||
|
||||
@grad real(x::_TrackedComplex) = real(data(x)), r̄ -> (r̄ + zero(r̄)*im,)
|
||||
@grad imag(x::_TrackedComplex) = imag(data(x)), ī -> (zero(ī) + ī*im,)
|
||||
|
||||
unwrap(x::_TrackedComplex) = real(x) + imag(x)*im
|
||||
|
||||
track(f::Call, x::Complex) =
|
||||
unwrap(_TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))))
|
||||
|
||||
param(x::Complex) = _TrackedComplex(float(x))
|
||||
|
||||
# External interface
|
||||
|
||||
TrackedComplex{T<:Real} = Complex{TrackedReal{T}}
|
||||
|
||||
data(x::TrackedComplex) = data(real(x)) + data(imag(x))*im
|
||||
|
||||
tracker(x::TrackedComplex) =
|
||||
Tracked{typeof(data(x))}(Call(c -> (real(c), imag(c)),
|
||||
(tracker(real(x)),tracker(imag(x)))),
|
||||
zero(data(x)))
|
||||
|
||||
function Base.show(io::IO, x::TrackedComplex)
|
||||
show(io, data(x))
|
||||
print(io, " (tracked)")
|
||||
end
|
||||
|
||||
Base.log(x::TrackedComplex) = track(log, x)
|
||||
|
||||
@grad log(x::TrackedComplex) = log(data(x)), ȳ -> (ȳ/x,)
|
|
@ -11,9 +11,8 @@ tracker(x::TrackedReal) = x.tracker
|
|||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||
|
||||
function back!(x::TrackedReal; once = true)
|
||||
isinf(x) && error("Loss is Inf")
|
||||
isnan(x) && error("Loss is NaN")
|
||||
return back!(x, 1, once = once)
|
||||
losscheck(data(x))
|
||||
return back!(x, 1, once = once)
|
||||
end
|
||||
|
||||
function Base.show(io::IO, x::TrackedReal)
|
||||
|
@ -32,7 +31,7 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
|
|||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||
error("Not implemented: convert tracked $S to tracked $T")
|
||||
|
||||
for op in [:(==), :≈, :<]
|
||||
for op in [:(==), :≈, :<, :<=]
|
||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||
|
|
|
@ -291,4 +291,6 @@ end
|
|||
@test count == 3
|
||||
end
|
||||
|
||||
@test Tracker.gradient(x -> abs2(log(x)), 1+2im)[1] isa Complex
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue