This commit is contained in:
Mike J Innes 2018-02-07 22:52:46 +00:00
parent 39f7f8fdf3
commit 0ac924e8e1
2 changed files with 13 additions and 7 deletions

View File

@ -3,8 +3,13 @@ DataFlow 0.2.1
Juno Juno
MacroTools 0.3.3 MacroTools 0.3.3
NNlib NNlib
ForwardDiff 0.5.0
Requires Requires
Adapt Adapt
GZip GZip
Colors Colors
# AD
ForwardDiff 0.5.0
DiffRules
SpecialFunctions
NaNMath

View File

@ -52,17 +52,18 @@ using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict()) _graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
function _graph(x, inputs...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x] haskey(cache, x) && return cache[x]
i = findfirst(inputs, x) i = findfirst(inputs, x)
cache[x] = cache[x] =
i > 0 ? inputnode(i) : i > 0 ? inputnode(i) :
isleaf(x) ? constant(x) : istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...) constant(x)
end end
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
function graph(f, args...) function graph(f, args...)
inputs = param.(args) inputs = param.(args)
_graph(f(inputs...), inputs...) _graph(f(inputs...), inputs...)
@ -70,6 +71,6 @@ end
import Adapt.adapt import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad)) adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
end end