diff --git a/src/Flux.jl b/src/Flux.jl
index 0d1fbf63..2e124655 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -18,6 +18,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
 
 include("tracker/Tracker.jl")
 using .Tracker
+export Tracker
 import .Tracker: data, value
 
 include("optimise/Optimise.jl")
diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl
index 5686c7a2..0b8ee3cd 100644
--- a/src/tracker/Tracker.jl
+++ b/src/tracker/Tracker.jl
@@ -93,6 +93,40 @@ include("back.jl")
 include("lib.jl")
 include("numeric.jl")
 
+using DataFlow
+using DataFlow: inputnode, constant
+
+# Build a DataFlow call vertex whose first input is the constant `f`, followed
+# by `args`. A `Broadcasted` wrapper is expressed as `broadcast` applied to the
+# function it wraps (`f.f`).
+vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...)
+vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...)
+
+# Convert the call history recorded on `x` into a DataFlow graph. An array that
+# is one of `inputs` becomes `inputnode(i)`, a leaf becomes a constant, and any
+# other tracked array becomes a call to `x.f.func` on the graphs of `x.f.args`.
+# `cache` memoises results per object, so a shared array maps to a single node.
+function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict())
+  haskey(cache, x) && return cache[x]
+  # Match inputs by identity, consistent with the `ObjectIdDict` cache; a
+  # value-based `==` lookup could pick a different input holding equal values.
+  i = findfirst(y -> y === x, inputs)
+  cache[x] =
+    i > 0 ? inputnode(i) :
+    isleaf(x) ? constant(x) :
+    vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...)
+end
+
+# Anything that is not a `TrackedArray` is embedded in the graph as a constant.
+_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x)
+
+# Trace `f` applied to `args`: wrap each argument with `param`, run `f`, and
+# return the DataFlow graph of the result with the arguments as input nodes.
+function graph(f, args...)
+  inputs = param.(args)
+  _graph(f(inputs...), inputs...)
+end
+
 import Adapt.adapt
 
 adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad))
diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl
index b8de5be1..e4fcdc4b 100644
--- a/src/tracker/lib.jl
+++ b/src/tracker/lib.jl
@@ -167,7 +167,8 @@ back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
 
 using ForwardDiff: Dual, partials
 
-struct Broadcasted{T}
+struct Broadcasted{F,T}
+  f::F
   data::T
 end
 
@@ -180,9 +181,9 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N
   dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
   out = broadcast(f, dargs...)
   eltype(out) <: Dual || return out
-  # TrackedArray(Call(Broadcasted(broadcast(f, dargs...)), args...))
+  # TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
   # Works around a 0.6 type inference issue
-  b = Broadcasted(out)
+  b = Broadcasted(f, out)
   TrackedArray(Call(b, args...), b())
 end
 