From cca4d25a10b55184d362db4ce70eb4349d92d04e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 6 Sep 2017 23:09:32 -0400 Subject: [PATCH] efficient traversal --- src/layers/recurrent.jl | 2 +- src/optimise/Optimise.jl | 2 +- src/tracker/Tracker.jl | 3 ++- src/tracker/back.jl | 49 +++++++++++++++++++++++++++++++--------- src/tracker/lib.jl | 30 ++++++++++++------------ test/tracker.jl | 5 ++++ 6 files changed, 62 insertions(+), 29 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index f679ae37..a7b98129 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -85,7 +85,7 @@ function LSTMCell(in, out; init = initn) cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]..., Dense(in+out, out, tanh, init = initn), track(initn(out)), track(initn(out))) - cell.forget.b.x .= 1 + cell.forget.b.data .= 1 return cell end diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl index 26fd2771..57c202eb 100644 --- a/src/optimise/Optimise.jl +++ b/src/optimise/Optimise.jl @@ -12,6 +12,6 @@ using Flux.Tracker: TrackedArray params(ps, p::TrackedArray) = push!(ps, p) -Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ) +Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[]) end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index ebc38f35..d0e33941 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -17,6 +17,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) (c::Call)() = c.func(data.(c.args)...) struct TrackedArray{T,N,A} <: AbstractArray{T,N} + ref::RefValue{UInt32} f::Call data::A grad::RefValue{A} @@ -28,7 +29,7 @@ TrackedMatrix{T,A} = TrackedArray{T,2,A} TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) + TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ) TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}()) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 42d70001..d11422ea 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,16 +1,43 @@ -back!(c::Call, Δ) = back!(c.func, Δ, c.args...) -back!(::Call{Void}, Δ) = nothing +scan(x) = nothing + +scan(c::Call) = foreach(scan, c.args) + +function scan(x::TrackedArray) + ref = x.ref[] += 1 + if ref == 1 + scan(x.f) + else + isassigned(x.grad) || (x.grad[] = zeros(x.data)) + end + return +end + +back(c::Call, Δ) = back(c.func, Δ, c.args...) +back(::Call{Void}, Δ) = nothing + +function back(x::TrackedArray, Δ) + ref = x.ref[] -= 1 + if isassigned(x.grad) + x.grad[] .+= Δ + ref == 0 && back(x.f, x.grad[]) + else + ref == 0 && back(x.f, Δ) + end + return +end + +macro back(x, Δ) + quote + x = $(esc(x)) + istracked(x) && back(x, $(esc(Δ))) + end +end + +# Interface methods function back!(x::TrackedArray, Δ) - isassigned(x.grad) && (x.grad[] .+= Δ) - back!(x.f, Δ) + scan(x) + back(x, Δ) end back!(x::TrackedScalar) = back!(x, 1) - -macro back!(x, Δ) - quote - x = $(esc(x)) - istracked(x) && back!(x, $(esc(Δ))) - end -end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 71d2ed0f..254be8dc 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -9,21 +9,21 @@ unarray(xs::AbstractArray{T,0} where T) = xs[] Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...])) -function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) +function back(::typeof(getindex), Δ, xs::TrackedArray, i...) Δ′ = zeros(xs.data) Δ′[i...] = unarray(Δ) - @back!(xs, Δ′) + @back(xs, Δ′) end Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) -back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ) +back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ) Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs)) Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) -back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.')) -back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ')) +back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.')) +back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...)) Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...)) @@ -40,10 +40,10 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) -function back!(::typeof(vcat), Δ, xs, ys) +function back(::typeof(vcat), Δ, xs, ys) i = Base.tail(map(_ -> :, size(Δ))) - @back!(xs, Δ[1:size(xs,1), i...]) - @back!(ys, Δ[size(xs,1)+1:end, i...]) + @back(xs, Δ[1:size(xs,1), i...]) + @back(ys, Δ[size(xs,1)+1:end, i...]) end # Reductions @@ -52,7 +52,7 @@ Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data))) Base.sum(xs::TrackedScalar, dim...) = xs -back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ) +back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) @@ -67,9 +67,9 @@ a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) -function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) - @back!(a, A_mul_Bt(Δ, data(b))) - @back!(b, At_mul_B(data(a), Δ)) +function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) + @back(a, A_mul_Bt(Δ, data(b))) + @back(b, At_mul_B(data(a), Δ)) end # NNlib @@ -78,7 +78,7 @@ import NNlib: softmax, ∇softmax softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) -back!(::typeof(softmax), Δ, xs) = @back!(xs, ∇softmax(Δ, data(xs))) +back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) # Broadcasting @@ -112,9 +112,9 @@ function getpartial(Δ, x, i) return Δ * p end -function back!(b::Broadcasted, Δ, args::Vararg{Any,N}) where N +function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N}) - foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs) + foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs) end Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray diff --git a/test/tracker.jl b/test/tracker.jl index 0fa4598b..7c2b32ea 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -22,4 +22,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(2,3), rand(3,3)) +@test gradtest(rand(5)) do x + y = x.^2 + 2y + x +end + end