2017-08-19 10:06:19 +00:00
|
|
|
module Tracker
|
|
|
|
|
2017-08-19 15:20:53 +00:00
|
|
|
export track, back!
|
|
|
|
|
2017-08-19 09:14:50 +00:00
|
|
|
"""
    data(x)

Return the raw value underlying `x`. This generic fallback hands back `x`
itself unchanged.
"""
function data(x)
  return x
end
|
2017-08-19 16:40:07 +00:00
|
|
|
"""
    istracked(x)

Report whether `x` is a tracked value. This fallback answers `false` for
any argument not covered by a more specific method.
"""
istracked(::Any) = false
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-18 15:50:27 +00:00
|
|
|
"""
    Call(f, args...)

A recorded function application: the function `f` together with the
positional arguments `args` it is applied to.
"""
struct Call{F,As<:Tuple}
  func::F   # the function to apply
  args::As  # positional arguments, stored as a tuple
end

# Capture `f` and its varargs, parameterising the struct on their concrete types.
function Call(f, args...)
  F = typeof(f)
  As = typeof(args)
  return Call{F,As}(f, args)
end
|
|
|
|
|
2017-08-19 09:14:50 +00:00
|
|
|
# Invoke a recorded call on the `data`-unwrapped form of each stored argument.
function (c::Call)()
  plain = map(data, c.args)
  return c.func(plain...)
end
|
|
|
|
|
2017-08-18 15:50:27 +00:00
|
|
|
# Propagate the sensitivity `Δ` through a recorded call by dispatching on its
# function, passing the original arguments along.
function back!(c::Call, Δ)
  return back!(c.func, Δ, c.args...)
end

# A call whose function is `nothing` (as built by `Call(nothing)`) has nothing
# further to propagate into.
function back!(::Call{Void}, Δ)
  return nothing
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-19 10:11:25 +00:00
|
|
|
"""
    TrackedArray(x::AbstractArray)
    TrackedArray(c::Call[, x[, Δ]])

An array wrapper pairing a value `x` with the `Call` associated with it and a
gradient buffer `Δ` that has the same array type as `x`.
"""
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
  f::Call  # recorded call for this value; `Call(nothing)` for a plain leaf
  x::A     # the wrapped value
  Δ::A     # gradient accumulator, sharing the array type `A` with `x`
end

# Rank-specialised aliases.
TrackedScalar{T,A} = TrackedArray{T,0,A}
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}

# Fully specified constructor: element type and rank are read off the array type `A`.
function TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray
  T = eltype(A)
  N = ndims(A)
  return TrackedArray{T,N,A}(c, x, Δ)
end

# Without an explicit gradient, start from the zero-filled array `zeros(x)`.
# NOTE(review): the three-argument method requires `zeros(x)` to have exactly
# the same type as `x`; confirm this holds for every array type tracked.
function TrackedArray(c::Call, x::AbstractArray)
  Δ = zeros(x)
  return TrackedArray(c, x, Δ)
end

# Evaluate the call to obtain the value to wrap.
function TrackedArray(c::Call)
  value = c()
  return TrackedArray(c, value)
end

# Wrap a plain array as a leaf whose call is `Call(nothing)`.
function TrackedArray(x::AbstractArray)
  leaf = Call(nothing)
  return TrackedArray(leaf, x)
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-19 10:11:25 +00:00
|
|
|
"""
    track(xs)

Wrap `xs` in a `TrackedArray`.
"""
function track(xs)
  return TrackedArray(xs)
end
|
2017-08-19 16:40:07 +00:00
|
|
|
# A TrackedArray always reports itself as tracked.
function istracked(::TrackedArray)
  return true
end

# Unwrap the value held by a TrackedArray.
function data(x::TrackedArray)
  return x.x
end

# The gradient buffer that `back!` accumulates into.
function grad(x::TrackedArray)
  return x.Δ
end
|
2017-08-18 15:50:27 +00:00
|
|
|
|
2017-08-23 16:50:43 +00:00
|
|
|
"""
    tovec(xs)

Flatten an array into a vector with `vec`. Any non-array value is returned
unchanged.
"""
tovec(xs) = xs
tovec(xs::AbstractArray) = vec(xs)
|
|
|
|
|
2017-08-19 10:11:25 +00:00
|
|
|
# Accumulate the sensitivity `Δ` into `x`'s gradient buffer, then continue
# backpropagation through the call recorded for `x`.
function back!(x::TrackedArray, Δ)
  # Both sides are flattened, so `Δ` need not have exactly the same shape as
  # `x.Δ`. A non-array `Δ` passes through `tovec` unchanged and is broadcast
  # over every entry. The in-place update relies on `vec` sharing memory
  # with `x.Δ`.
  buf = vec(x.Δ)
  buf .+= tovec(Δ)
  return back!(x.f, Δ)
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-23 16:50:49 +00:00
|
|
|
# Start backpropagation from a 0-dimensional tracked value with sensitivity 1.
function back!(x::TrackedScalar)
  return back!(x, 1)
end
|
|
|
|
|
2017-08-19 16:40:07 +00:00
|
|
|
"""
    @back!(x, Δ)

Evaluate `x` once. If it is tracked, call `back!(x, Δ)`. The expression `Δ`
is evaluated only when `x` is tracked. The result is `false` for untracked
values, otherwise whatever `back!` returns.
"""
macro back!(x, Δ)
  tmp = gensym(:x)
  return quote
    $tmp = $(esc(x))
    istracked($tmp) && back!($tmp, $(esc(Δ)))
  end
end
|
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# Fallthrough methods
|
|
|
|
|
|
|
|
# Forward `size` and `ndims` queries (with any extra arguments) on a
# TrackedArray to its wrapped data.
for fname in (:(Base.size), :(Base.ndims))
  @eval @inline $fname(x::TrackedArray, a...) = $fname(data(x), a...)
end
|
2017-08-19 10:00:55 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
# `similar` delegates to the wrapped data, so the result is built from
# `data(x)` rather than wrapped in a new TrackedArray.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
  return similar(data(x), dims...)
end

function Base.similar(x::TrackedArray, T::Type)
  return similar(data(x), T)
end
|
|
|
|
|
2017-08-19 10:11:25 +00:00
|
|
|
# Display hook for TrackedArray.
# - `repr == true`: print a `track(...)` expression around the wrapped data.
# - otherwise: optionally prefix "Tracked " and show the wrapped data.
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
  if !repr
    header && print(io, "Tracked ")
    return Base.showarray(io, data(X), false, header = header)
  end
  print(io, "track(")
  Base.showarray(io, data(X), true)
  return print(io, ")")
end
|
2017-08-19 10:06:19 +00:00
|
|
|
|
2017-08-19 15:02:19 +00:00
|
|
|
include("lib.jl")
|
2017-08-23 00:43:45 +00:00
|
|
|
include("numeric.jl")
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2017-08-19 10:06:19 +00:00
|
|
|
end
|