Compare commits

...

1 Commits

Author SHA1 Message Date
Mike J Innes 3eb5e07ded basic TrackedComplex 2018-11-01 15:37:11 +00:00
5 changed files with 57 additions and 11 deletions

View File

@ -70,6 +70,7 @@ include("idset.jl")
include("back.jl")
include("numeric.jl")
include("lib/real.jl")
include("lib/complex.jl")
include("lib/array.jl")
"""

View File

@ -14,10 +14,7 @@ function scan(x::Tracked)
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
end
scan(::Nothing) = return
function back_(c::Call, Δ, once)
Δs = c.func(Δ)
@ -61,7 +58,7 @@ back(::Nothing, Δ, once) = return
function back!(x, Δ; once = true)
istracked(x) || return
scan(x)
scan(tracker(x))
back(tracker(x), Δ, once)
return
end
@ -143,16 +140,19 @@ function forward(f, ps::Params)
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
scan(tracker(y))
back(g, tracker(y), Δ)
end
return g
end
end
# Essentially a hack for complex numbers
unwrap(x) = x
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, back = forward(() -> f(unwrap.(args)...), Params(args))
y, Δ -> getindex.(Ref(back(Δ)), args)
end

View File

@ -0,0 +1,44 @@
# Internal interface
struct _TrackedComplex{T<:Real}
data::Complex{T}
tracker::Tracked{Complex{T}}
end
_TrackedComplex(x::Complex) = _TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::_TrackedComplex) = x.data
tracker(x::_TrackedComplex) = x.tracker
Base.real(x::_TrackedComplex) = track(real, x)
Base.imag(x::_TrackedComplex) = track(imag, x)
@grad real(x::_TrackedComplex) = real(data(x)), -> ( + zero()*im,)
@grad imag(x::_TrackedComplex) = imag(data(x)), -> (zero() + *im,)
unwrap(x::_TrackedComplex) = real(x) + imag(x)*im
track(f::Call, x::Complex) =
unwrap(_TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))))
param(x::Complex) = _TrackedComplex(float(x))
# External interface
TrackedComplex{T<:Real} = Complex{TrackedReal{T}}
data(x::TrackedComplex) = data(real(x)) + data(imag(x))*im
tracker(x::TrackedComplex) =
Tracked{typeof(data(x))}(Call(c -> (real(c), imag(c)),
(tracker(real(x)),tracker(imag(x)))),
zero(data(x)))
function Base.show(io::IO, x::TrackedComplex)
show(io, data(x))
print(io, " (tracked)")
end
Base.log(x::TrackedComplex) = track(log, x)
@grad log(x::TrackedComplex) = log(data(x)), -> (/x,)

View File

@ -11,9 +11,8 @@ tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal; once = true)
isinf(x) && error("Loss is Inf")
isnan(x) && error("Loss is NaN")
return back!(x, 1, once = once)
losscheck(data(x))
return back!(x, 1, once = once)
end
function Base.show(io::IO, x::TrackedReal)
@ -32,7 +31,7 @@ Base.convert(::Type{TrackedReal{T}}, x::Real) where T = TrackedReal(convert(T, x
Base.convert(::Type{TrackedReal{T}}, x::TrackedReal{S}) where {T,S} =
error("Not implemented: convert tracked $S to tracked $T")
for op in [:(==), :≈, :<]
for op in [:(==), :≈, :<, :<=]
@eval Base.$op(x::TrackedReal, y::Real) = Base.$op(data(x), y)
@eval Base.$op(x::Real, y::TrackedReal) = Base.$op(x, data(y))
@eval Base.$op(x::TrackedReal, y::TrackedReal) = Base.$op(data(x), data(y))

View File

@ -291,4 +291,6 @@ end
@test count == 3
end
@test Tracker.gradient(x -> abs2(log(x)), 1+2im)[1] isa Complex
end #testset