2017-08-19 10:06:19 +00:00
|
|
|
module Tracker
|
|
|
|
|
2017-11-09 14:53:26 +00:00
|
|
|
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
2017-08-19 15:20:53 +00:00
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# Fallback accessor for the tracker node of `x`: a plain value has none, so
# this returns `nothing` (which is what `istracked` tests for). Tracked wrapper
# types presumably add their own `tracker` methods elsewhere — confirm.
function tracker(x)
    return nothing
end
|
|
|
|
|
|
|
|
# Return `true` when `x` has a tracker node, i.e. `tracker(x)` is not `nothing`.
# `!==` is an identity test against the `nothing` singleton: unlike `≠` it never
# dispatches to a user-defined `==` and always yields a `Bool`.
istracked(x) = tracker(x) !== nothing
|
|
|
|
# Untracked values are always leaves of the graph; for tracked values the
# decision is delegated to `isleaf` on their tracker node.
function isleaf(x)
    istracked(x) || return true
    return isleaf(tracker(x))
end
|
|
|
|
# Strip tracking: for a tracked value return the data held by its tracker
# node; an untracked value is returned unchanged.
function data(x)
    if istracked(x)
        return data(tracker(x))
    else
        return x
    end
end
|
|
|
|
# Gradient of `x`, read from its tracker node. For a value that uses the
# fallback `tracker` (returning `nothing`) this calls `grad(nothing)`, which
# fails unless such a method is defined elsewhere — NOTE(review): confirm intent.
function grad(x)
    node = tracker(x)
    return grad(node)
end
|
2017-08-19 09:14:50 +00:00
|
|
|
|
2017-08-18 15:50:27 +00:00
|
|
|
"""
    Call{F,As<:Tuple}

A recorded function application: the callee `func` paired with the tuple
`args` of positional arguments it is applied to.
"""
struct Call{F,As<:Tuple}
    func::F   # function to apply
    args::As  # positional arguments, stored as a tuple
end
|
|
|
|
|
|
|
|
# Varargs convenience constructor: `Call(f, a, b)` records `f` applied to the
# tuple `(a, b)`, with type parameters taken from the runtime types of `f`
# and of the argument tuple.
function Call(f, args...)
    F = typeof(f)
    As = typeof(args)
    return Call{F,As}(f, args)
end
|
|
|
|
|
2017-08-19 09:14:50 +00:00
|
|
|
# Invoking a `Call` replays it on untracked values: each recorded argument is
# passed through `data` before `func` is applied to the results.
function (c::Call)()
    raw = map(data, c.args)
    return c.func(raw...)
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
"""
    Tracked{T}

Mutable tracker node for a value of type `T`. It stores the `Call` associated
with the value (`f`), the value itself (`data`), its gradient (`grad`), and a
`UInt32` counter `ref` that both constructors initialise to zero.
"""
mutable struct Tracked{T}
    ref::UInt32  # starts at zero; presumably a reference count used during backprop — confirm in back.jl
    f::Call      # the recorded operation for this node
    data::T      # the tracked value
    grad::T      # gradient; left unassigned by the two-argument constructor
    # Gradient not supplied: `grad` stays unassigned (undefined for reference types).
    function Tracked{T}(f::Call, data::T) where T
        return new(zero(UInt32), f, data)
    end
    # Gradient supplied explicitly.
    function Tracked{T}(f::Call, data::T, grad::T) where T
        return new(zero(UInt32), f, data, grad)
    end
end
|
2017-08-19 10:00:55 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Infer the type parameter from the value: `Tracked(f, x)` builds a
# `Tracked{typeof(x)}` node with no gradient assigned.
function Tracked(f::Call, x)
    T = typeof(x)
    return Tracked{T}(f, x)
end
|
|
|
|
|
|
|
|
# Wrap the value `x` in a `Tracked` node that records the call `f`.
function track(f::Call, x)
    return Tracked(f, x)
end
|
|
|
|
# Evaluate the recorded call, then track its result together with that call.
function track(f::Call)
    result = f()
    return track(f, result)
end
|
|
|
|
# Convenience form: `track(f, a, b)` records `f` applied to `(a, b)` as a
# `Call`, evaluates it and tracks the result.
function track(f, xs...)
    c = Call(f, xs...)
    return track(c)
end
|
|
|
|
|
2018-02-07 17:43:25 +00:00
|
|
|
# A `Tracked` node is itself always considered tracked.
function istracked(x::Tracked)
    return true
end
|
|
|
|
# A node is a leaf when its recorded call equals `Call(nothing)`: no function
# and no arguments, i.e. the value was not produced by a recorded operation.
function isleaf(x::Tracked)
    leafcall = Call(nothing)
    return x.f == leafcall
end
|
|
|
|
# The value stored in a tracker node.
function data(x::Tracked)
    return getfield(x, :data)
end
|
|
|
|
# The gradient stored in a tracker node. Reading it throws if the field was
# never assigned (possible for reference types; see the two-argument constructor).
function grad(x::Tracked)
    return getfield(x, :grad)
end
|
2017-10-26 10:15:14 +00:00
|
|
|
|
2017-09-07 01:21:35 +00:00
|
|
|
include("back.jl")
|
2018-02-07 20:39:36 +00:00
|
|
|
include("scalar.jl")
|
2018-02-07 17:43:25 +00:00
|
|
|
include("array.jl")
|
2017-08-23 00:43:45 +00:00
|
|
|
include("numeric.jl")
|
2017-08-19 15:02:19 +00:00
|
|
|
|
2018-02-07 20:39:36 +00:00
|
|
|
# Make a trainable scalar: convert `x` to floating point and wrap it for tracking.
function param(x::Number)
    return TrackedNumber(float(x))
end
|
|
|
|
# Make a trainable array: convert each element to floating point (producing a
# new array) and wrap the result for tracking.
function param(xs::AbstractArray)
    floats = float.(xs)
    return TrackedArray(floats)
end
|
2018-02-07 17:43:25 +00:00
|
|
|
|
2018-02-05 17:22:09 +00:00
|
|
|
using DataFlow
|
|
|
|
using DataFlow: inputnode, constant
|
|
|
|
|
|
|
|
# Build a DataFlow call vertex that applies the constant function `f` to the
# argument vertices `args`.
function vcall(f, args...)
    op = DataFlow.Call()
    fnode = constant(f)
    return vertex(op, fnode, args...)
end
|
|
|
|
# A `Broadcasted` function is graphed as a call to `broadcast`, with the
# wrapped function `f.f` inserted as a constant first argument.
function vcall(f::Broadcasted, args...)
    inner = constant(f.f)
    return vcall(broadcast, inner, args...)
end
|
|
|
|
|
|
|
|
"""
    _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())

Convert the recorded computation that produced `x` into a DataFlow graph.

- If `x` is (by identity) the `i`-th element of `inputs`, it becomes `inputnode(i)`.
- Otherwise, if `x` is a leaf, it becomes `constant(x)`.
- Otherwise it becomes a call vertex for its recorded function, with every
  recorded argument converted recursively.

`cache` is keyed by object identity and memoises the node built for each
array, so an array reached along several paths maps to a single node.
"""
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
    haskey(cache, x) && return cache[x]
    # Locate `x` among the inputs by identity, consistent with the identity-keyed
    # cache. Comparing with `==` (as `findfirst(inputs, x)` does) is elementwise
    # for arrays, so an intermediate result that merely *equals* an input would
    # be collapsed into that input node.
    i = findfirst(input -> input === x, inputs)
    # `findfirst` reports "not found" as `0` (Julia 0.6) or `nothing` (Julia >= 0.7).
    isinput = i !== nothing && i > 0
    node = if isinput
        inputnode(i)
    elseif isleaf(x)
        constant(x)
    else
        # The recorded call (`f`) is a field of the `Tracked` node, reached via
        # the `tracker` accessor; assumes `TrackedArray` defines `tracker` — confirm in array.jl.
        call = tracker(x).f
        argnodes = map(arg -> _graph(arg, inputs...; cache = cache), call.args)
        vcall(call.func, argnodes...)
    end
    cache[x] = node
    return node
end
|
|
|
|
|
|
|
|
# Any `x` that is not a `TrackedArray` is embedded in the graph as a constant;
# `inputs` and `cache` are accepted only to mirror the `TrackedArray` method.
function _graph(x, inputs::TrackedArray...; cache = ObjectIdDict())
    return constant(x)
end
|
|
|
|
|
|
|
|
"""
    graph(f, args...)

Run `f` on tracked versions of `args` (each made with `param`) and return the
DataFlow graph of the computation that produced its result.
"""
function graph(f, args...)
    tracked = map(param, args)
    output = f(tracked...)
    return _graph(output, tracked...)
end
|
|
|
|
|
2018-01-08 16:31:23 +00:00
|
|
|
import Adapt.adapt
|
2017-08-24 16:00:48 +00:00
|
|
|
|
2017-10-18 21:54:58 +00:00
|
|
|
# Convert a tracked array's value and gradient with `adapt(T, …)`, keeping its
# recorded `Call`. The call is read from the tracker node (the `Tracked` type
# owns the `f` field) and the value and gradient through the `data`/`grad`
# accessors, rather than assuming `TrackedArray` has fields named `f`/`data`/`grad`.
# As before, this throws if the gradient has never been assigned.
adapt(T, xs::TrackedArray) =
    TrackedArray(tracker(xs).f, adapt(T, data(xs)), adapt(T, grad(xs)))
|
2017-08-24 16:00:48 +00:00
|
|
|
|
2017-08-19 10:06:19 +00:00
|
|
|
end
|