fixups
This commit is contained in:
parent
39f7f8fdf3
commit
0ac924e8e1
7
REQUIRE
7
REQUIRE
|
@ -3,8 +3,13 @@ DataFlow 0.2.1
|
|||
Juno
|
||||
MacroTools 0.3.3
|
||||
NNlib
|
||||
ForwardDiff 0.5.0
|
||||
Requires
|
||||
Adapt
|
||||
GZip
|
||||
Colors
|
||||
|
||||
# AD
|
||||
ForwardDiff 0.5.0
|
||||
DiffRules
|
||||
SpecialFunctions
|
||||
NaNMath
|
||||
|
|
|
@ -52,17 +52,18 @@ using DataFlow: inputnode, constant
|
|||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
|
||||
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
|
||||
function _graph(x, inputs...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
isleaf(x) ? constant(x) :
|
||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
|
||||
constant(x)
|
||||
end
|
||||
|
||||
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
|
||||
|
||||
function graph(f, args...)
|
||||
inputs = param.(args)
|
||||
_graph(f(inputs...), inputs...)
|
||||
|
@ -70,6 +71,6 @@ end
|
|||
|
||||
import Adapt.adapt
|
||||
|
||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue