get tracker graph

This commit is contained in:
Mike J Innes 2018-02-05 17:22:09 +00:00
parent 49e1e78f67
commit 2a2475a9c2
3 changed files with 27 additions and 3 deletions

View File

@ -18,6 +18,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
include("tracker/Tracker.jl")
using .Tracker
export Tracker
import .Tracker: data, value
include("optimise/Optimise.jl")

View File

@ -93,6 +93,28 @@ include("back.jl")
include("lib.jl")
include("numeric.jl")
using DataFlow
using DataFlow: inputnode, constant
vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
haskey(cache, x) && return cache[x]
i = findfirst(inputs, x)
cache[x] =
i > 0 ? inputnode(i) :
isleaf(x) ? constant(x) :
vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
end
_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
function graph(f, args...)
inputs = param.(args)
_graph(f(inputs...), inputs...)
end
import Adapt.adapt
adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))

View File

@ -167,7 +167,8 @@ back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
using ForwardDiff: Dual, partials
struct Broadcasted{T}
struct Broadcasted{F,T}
f::F
data::T
end
@ -180,9 +181,9 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
# TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
# TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
# Works around a 0.6 type inference issue
b = Broadcasted(out)
b = Broadcasted(f, out)
TrackedArray(Call(b, args...), b())
end