basic TrackedComplex
This commit is contained in:
parent
c71c610747
commit
3eb5e07ded
@ -70,6 +70,7 @@ include("idset.jl")
|
|||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
include("lib/real.jl")
|
include("lib/real.jl")
|
||||||
|
include("lib/complex.jl")
|
||||||
include("lib/array.jl")
|
include("lib/array.jl")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -14,10 +14,7 @@ function scan(x::Tracked)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
function scan(x)
|
scan(::Nothing) = return
|
||||||
istracked(x) && scan(tracker(x))
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function back_(c::Call, Δ, once)
|
function back_(c::Call, Δ, once)
|
||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
@ -61,7 +58,7 @@ back(::Nothing, Δ, once) = return
|
|||||||
|
|
||||||
function back!(x, Δ; once = true)
|
function back!(x, Δ; once = true)
|
||||||
istracked(x) || return
|
istracked(x) || return
|
||||||
scan(x)
|
scan(tracker(x))
|
||||||
back(tracker(x), Δ, once)
|
back(tracker(x), Δ, once)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
@ -143,16 +140,19 @@ function forward(f, ps::Params)
|
|||||||
y, function (Δ)
|
y, function (Δ)
|
||||||
g = Grads(ps)
|
g = Grads(ps)
|
||||||
if istracked(y)
|
if istracked(y)
|
||||||
scan(y)
|
scan(tracker(y))
|
||||||
back(g, tracker(y), Δ)
|
back(g, tracker(y), Δ)
|
||||||
end
|
end
|
||||||
return g
|
return g
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Essentially a hack for complex numbers
|
||||||
|
unwrap(x) = x
|
||||||
|
|
||||||
function forward(f, args...)
|
function forward(f, args...)
|
||||||
args = param.(args)
|
args = param.(args)
|
||||||
y, back = forward(() -> f(args...), Params(args))
|
y, back = forward(() -> f(unwrap.(args)...), Params(args))
|
||||||
y, Δ -> getindex.(Ref(back(Δ)), args)
|
y, Δ -> getindex.(Ref(back(Δ)), args)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
44
src/tracker/lib/complex.jl
Normal file
44
src/tracker/lib/complex.jl
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Internal interface
|
||||||
|
|
||||||
|
struct _TrackedComplex{T<:Real}
|
||||||
|
data::Complex{T}
|
||||||
|
tracker::Tracked{Complex{T}}
|
||||||
|
end
|
||||||
|
|
||||||
|
_TrackedComplex(x::Complex) = _TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x)))
|
||||||
|
|
||||||
|
data(x::_TrackedComplex) = x.data
|
||||||
|
tracker(x::_TrackedComplex) = x.tracker
|
||||||
|
|
||||||
|
Base.real(x::_TrackedComplex) = track(real, x)
|
||||||
|
Base.imag(x::_TrackedComplex) = track(imag, x)
|
||||||
|
|
||||||
|
@grad real(x::_TrackedComplex) = real(data(x)), r̄ -> (r̄ + zero(r̄)*im,)
|
||||||
|
@grad imag(x::_TrackedComplex) = imag(data(x)), ī -> (zero(ī) + ī*im,)
|
||||||
|
|
||||||
|
unwrap(x::_TrackedComplex) = real(x) + imag(x)*im
|
||||||
|
|
||||||
|
track(f::Call, x::Complex) =
|
||||||
|
unwrap(_TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))))
|
||||||
|
|
||||||
|
param(x::Complex) = _TrackedComplex(float(x))
|
||||||
|
|
||||||
|
# External interface
|
||||||
|
|
||||||
|
TrackedComplex{T<:Real} = Complex{TrackedReal{T}}
|
||||||
|
|
||||||
|
data(x::TrackedComplex) = data(real(x)) + data(imag(x))*im
|
||||||
|
|
||||||
|
tracker(x::TrackedComplex) =
|
||||||
|
Tracked{typeof(data(x))}(Call(c -> (real(c), imag(c)),
|
||||||
|
(tracker(real(x)),tracker(imag(x)))),
|
||||||
|
zero(data(x)))
|
||||||
|
|
||||||
|
function Base.show(io::IO, x::TrackedComplex)
|
||||||
|
show(io, data(x))
|
||||||
|
print(io, " (tracked)")
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.log(x::TrackedComplex) = track(log, x)
|
||||||
|
|
||||||
|
@grad log(x::TrackedComplex) = log(data(x)), ȳ -> (ȳ/x,)
|
@ -11,9 +11,8 @@ tracker(x::TrackedReal) = x.tracker
|
|||||||
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
|
||||||
|
|
||||||
function back!(x::TrackedReal; once = true)
|
function back!(x::TrackedReal; once = true)
|
||||||
isinf(x) && error("Loss is Inf")
|
losscheck(data(x))
|
||||||
isnan(x) && error("Loss is NaN")
|
return back!(x, 1, once = once)
|
||||||
return back!(x, 1, once = once)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function Base.show(io::IO, x::TrackedReal)
|
function Base.show(io::IO, x::TrackedReal)
|
||||||
@ -32,7 +31,7 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
|
|||||||
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
|
||||||
error("Not implemented: convert tracked $S to tracked $T")
|
error("Not implemented: convert tracked $S to tracked $T")
|
||||||
|
|
||||||
for op in [:(==), :≈, :<]
|
for op in [:(==), :≈, :<, :<=]
|
||||||
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
|
||||||
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
|
||||||
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))
|
||||||
|
@ -291,4 +291,6 @@ end
|
|||||||
@test count == 3
|
@test count == 3
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@test Tracker.gradient(x -> abs2(log(x)), 1+2im)[1] isa Complex
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user