get tracker graph
This commit is contained in:
parent
49e1e78f67
commit
2a2475a9c2
@ -18,6 +18,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
|
||||
|
||||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
export Tracker
|
||||
import .Tracker: data, value
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
|
@ -93,6 +93,28 @@ include("back.jl")
|
||||
include("lib.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
|
||||
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
|
||||
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
|
||||
|
||||
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
|
||||
haskey(cache, x) && return cache[x]
|
||||
i = findfirst(inputs, x)
|
||||
cache[x] =
|
||||
i > 0 ? inputnode(i) :
|
||||
isleaf(x) ? constant(x) :
|
||||
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
|
||||
end
|
||||
|
||||
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
|
||||
|
||||
function graph(f, args...)
|
||||
inputs = param.(args)
|
||||
_graph(f(inputs...), inputs...)
|
||||
end
|
||||
|
||||
import Adapt.adapt
|
||||
|
||||
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
|
||||
|
@ -167,7 +167,8 @@ back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
|
||||
|
||||
using ForwardDiff: Dual, partials
|
||||
|
||||
struct Broadcasted{T}
|
||||
struct Broadcasted{F,T}
|
||||
f::F
|
||||
data::T
|
||||
end
|
||||
|
||||
@ -180,9 +181,9 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N
|
||||
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
|
||||
# TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
|
||||
# Works around a 0.6 type inference issue
|
||||
b = Broadcasted(out)
|
||||
b = Broadcasted(f, out)
|
||||
TrackedArray(Call(b, args...), b())
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user