fixups
This commit is contained in:
parent
39f7f8fdf3
commit
0ac924e8e1
7
REQUIRE
7
REQUIRE
@ -3,8 +3,13 @@ DataFlow 0.2.1
|
|||||||
Juno
|
Juno
|
||||||
MacroTools 0.3.3
|
MacroTools 0.3.3
|
||||||
NNlib
|
NNlib
|
||||||
ForwardDiff 0.5.0
|
|
||||||
Requires
|
Requires
|
||||||
Adapt
|
Adapt
|
||||||
GZip
|
GZip
|
||||||
Colors
|
Colors
|
||||||
|
|
||||||
|
# AD
|
||||||
|
ForwardDiff 0.5.0
|
||||||
|
DiffRules
|
||||||
|
SpecialFunctions
|
||||||
|
NaNMath
|
||||||
|
@ -52,17 +52,18 @@ using DataFlow: inputnode, constant
|
|||||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||||
|
|
||||||
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
|
_graph(x::Tracked, inputs...; cache = ObjectIdDict()) =
|
||||||
|
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||||
|
|
||||||
|
function _graph(x, inputs...; cache = ObjectIdDict())
|
||||||
haskey(cache, x) && return cache[x]
|
haskey(cache, x) && return cache[x]
|
||||||
i = findfirst(inputs, x)
|
i = findfirst(inputs, x)
|
||||||
cache[x] =
|
cache[x] =
|
||||||
i > 0 ? inputnode(i) :
|
i > 0 ? inputnode(i) :
|
||||||
isleaf(x) ? constant(x) :
|
istracked(x) ? _graph(tracker(x), inputs...; cache = cache) :
|
||||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
constant(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
|
|
||||||
|
|
||||||
function graph(f, args...)
|
function graph(f, args...)
|
||||||
inputs = param.(args)
|
inputs = param.(args)
|
||||||
_graph(f(inputs...), inputs...)
|
_graph(f(inputs...), inputs...)
|
||||||
@ -70,6 +71,6 @@ end
|
|||||||
|
|
||||||
import Adapt.adapt
|
import Adapt.adapt
|
||||||
|
|
||||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
adapt(T, xs::TrackedArray) = param(adapt(T, data(xs)))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user