This commit is contained in:
Mike J Innes 2018-02-07 22:52:46 +00:00
parent 39f7f8fdf3
commit 0ac924e8e1
2 changed files with 13 additions and 7 deletions

View File

@ -3,8 +3,13 @@ DataFlow 0.2.1
Juno
MacroTools 0.3.3
NNlib
ForwardDiff 0.5.0
Requires
Adapt
GZip
Colors
# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath

View File

@ -52,17 +52,18 @@ using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
function _graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
isleaf(x) ? constant(x) :
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
constant(x)
end
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
function graph(f, args...)
inputs = param.(args)
_graph(f(inputs...), inputs...)
@ -70,6 +71,6 @@ end
import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end