efficient traversal

This commit is contained in:
Mike J Innes 2017-09-06 23:09:32 -04:00
parent 7cfc42d166
commit cca4d25a10
6 changed files with 62 additions and 29 deletions

View File

@ -85,7 +85,7 @@ function LSTMCell(in, out; init = initn)
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
Dense(in+out, out, tanh, init = initn),
track(initn(out)), track(initn(out)))
cell.forget.b.x .= 1
cell.forget.b.data .= 1
return cell
end

View File

@ -12,6 +12,6 @@ using Flux.Tracker: TrackedArray
params(ps, p::TrackedArray) = push!(ps, p)
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
end

View File

@ -17,6 +17,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
ref::RefValue{UInt32}
f::Call
data::A
grad::RefValue{A}
@ -28,7 +29,7 @@ TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())

View File

@ -1,16 +1,43 @@
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
back!(::Call{Void}, Δ) = nothing
scan(x) = nothing
scan(c::Call) = foreach(scan, c.args)
function scan(x::TrackedArray)
ref = x.ref[] += 1
if ref == 1
scan(x.f)
else
isassigned(x.grad) || (x.grad[] = zeros(x.data))
end
return
end
back(c::Call, Δ) = back(c.func, Δ, c.args...)
back(::Call{Void}, Δ) = nothing
function back(x::TrackedArray, Δ)
ref = x.ref[] -= 1
if isassigned(x.grad)
x.grad[] .+= Δ
ref == 0 && back(x.f, x.grad[])
else
ref == 0 && back(x.f, Δ)
end
return
end
macro back(x, Δ)
quote
x = $(esc(x))
istracked(x) && back(x, $(esc(Δ)))
end
end
# Interface methods
function back!(x::TrackedArray, Δ)
isassigned(x.grad) && (x.grad[] .+= Δ)
back!(x.f, Δ)
scan(x)
back(x, Δ)
end
back!(x::TrackedScalar) = back!(x, 1)
macro back!(x, Δ)
quote
x = $(esc(x))
istracked(x) && back!(x, $(esc(Δ)))
end
end

View File

@ -9,21 +9,21 @@ unarray(xs::AbstractArray{T,0} where T) = xs[]
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.data)
Δ′[i...] = unarray(Δ)
@back!(xs, Δ′)
@back(xs, Δ′)
end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.'))
back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ'))
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
@ -40,10 +40,10 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
function back!(::typeof(vcat), Δ, xs, ys)
function back(::typeof(vcat), Δ, xs, ys)
i = Base.tail(map(_ -> :, size(Δ)))
@back!(xs, Δ[1:size(xs,1), i...])
@back!(ys, Δ[size(xs,1)+1:end, i...])
@back(xs, Δ[1:size(xs,1), i...])
@back(ys, Δ[size(xs,1)+1:end, i...])
end
# Reductions
@ -52,7 +52,7 @@ Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
Base.sum(xs::TrackedScalar, dim...) = xs
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ)
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
@ -67,9 +67,9 @@ a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back!(a, A_mul_Bt(Δ, data(b)))
@back!(b, At_mul_B(data(a), Δ))
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(a, A_mul_Bt(Δ, data(b)))
@back(b, At_mul_B(data(a), Δ))
end
# NNlib
@ -78,7 +78,7 @@ import NNlib: softmax, ∇softmax
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
back!(::typeof(softmax), Δ, xs) = @back!(xs, ∇softmax(Δ, data(xs)))
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
# Broadcasting
@ -112,9 +112,9 @@ function getpartial(Δ, x, i)
return Δ * p
end
function back!(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs)
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
end
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray

View File

@ -22,4 +22,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(2,3), rand(3,3))
@test gradtest(rand(5)) do x
y = x.^2
2y + x
end
end