efficient traversal
This commit is contained in:
parent
7cfc42d166
commit
cca4d25a10
@ -85,7 +85,7 @@ function LSTMCell(in, out; init = initn)
|
|||||||
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
|
||||||
Dense(in+out, out, tanh, init = initn),
|
Dense(in+out, out, tanh, init = initn),
|
||||||
track(initn(out)), track(initn(out)))
|
track(initn(out)), track(initn(out)))
|
||||||
cell.forget.b.x .= 1
|
cell.forget.b.data .= 1
|
||||||
return cell
|
return cell
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -12,6 +12,6 @@ using Flux.Tracker: TrackedArray
|
|||||||
|
|
||||||
# Collect a tracked array `p` into the parameter collection `ps` via `push!`.
# (Unchanged context row: the same line appears in both diff columns.)
params(ps, p::TrackedArray) = push!(ps, p)
|
params(ps, p::TrackedArray) = push!(ps, p)
|
||||||
|
|
||||||
# Convert a TrackedArray into a `Param` built from its value and its gradient.
# First line (pre-commit) reads the fields `x.x` and `x.Δ`; second line
# (post-commit) reads `x.data` and dereferences the gradient ref with `x.grad[]`.
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.x, x.Δ)
|
Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -17,6 +17,7 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||||||
# Calling a `Call` object with no arguments applies its stored `func` to
# `data.(c.args)`, i.e. to `data` broadcast over the stored arguments.
# (Unchanged context row.)
(c::Call)() = c.func(data.(c.args)...)
|
(c::Call)() = c.func(data.(c.args)...)
|
||||||
|
|
||||||
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||||
|
ref::RefValue{UInt32}
|
||||||
f::Call
|
f::Call
|
||||||
data::A
|
data::A
|
||||||
grad::RefValue{A}
|
grad::RefValue{A}
|
||||||
@ -28,7 +29,7 @@ TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
|||||||
# Alias covering both tracked vectors and tracked matrices that share the
# parameters `T` and `A`. (Unchanged context row.)
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||||
|
|
||||||
# Outer constructor that fills in the type parameters from the array type `A`
# (`eltype(A)`, `ndims(A)`, `A`). The post-commit revision additionally passes
# `Ref(UInt32(0))` ahead of the other fields, matching the `ref` field this
# commit adds to the struct.
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
|
TrackedArray(c::Call, x::A, Δ::Ref{A}) where A <: AbstractArray =
|
||||||
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
TrackedArray{eltype(A),ndims(A),A}(Ref(UInt32(0)), c, x, Δ)
|
||||||
|
|
||||||
# Convenience constructor: wraps `x` with a freshly created, empty
# `RefValue{typeof(x)}()` as the gradient reference. (Unchanged context row.)
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
|
TrackedArray(c::Call, x::AbstractArray) = TrackedArray(c, x, RefValue{typeof(x)}())
|
||||||
|
|
||||||
|
@ -1,16 +1,43 @@
|
|||||||
# These rows pair unrelated definitions from the two revisions of the file.
# Pre-commit column: `back!` on a `Call` forwards Δ to the `back!` method for
# `c.func` with the call's arguments; `back!` on a `Call{Void}` does nothing.
# Post-commit column: `scan` of an arbitrary value does nothing, and `scan` of
# a `Call` scans each of its arguments.
back!(c::Call, Δ) = back!(c.func, Δ, c.args...)
|
scan(x) = nothing
|
||||||
back!(::Call{Void}, Δ) = nothing
|
|
||||||
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
|
|
||||||
|
# Counting pass run before backpropagation. Each visit bumps the array's `ref`
# counter. The first visit continues into the call that produced `x`; any later
# visit only makes sure a zero-filled gradient buffer exists in `x.grad`.
function scan(x::TrackedArray)
    visits = (x.ref[] += 1)
    if visits == 1
        # First time this array is reached: descend into its producing call.
        scan(x.f)
        return
    end
    # Reached again: ensure there is a gradient buffer to accumulate into.
    isassigned(x.grad) || (x.grad[] = zeros(x.data))
    return
end
|
||||||
|
|
||||||
|
# Hand a gradient arriving at a recorded call to the `back` method for the
# called function, together with the call's stored arguments.
function back(c::Call, Δ)
    return back(c.func, Δ, c.args...)
end

# A `Call{Void}` has no function to propagate into, so this is a no-op.
function back(::Call{Void}, Δ)
    return nothing
end
|
||||||
|
|
||||||
|
# Deliver one gradient contribution `Δ` to `x`. Each delivery decrements the
# `ref` counter. If `x` has a gradient buffer, `Δ` is added to it. Only when the
# counter reaches zero is the gradient passed on to the call that produced `x`:
# the accumulated buffer when one exists, otherwise `Δ` itself.
function back(x::TrackedArray, Δ)
    pending = (x.ref[] -= 1)
    accumulating = isassigned(x.grad)
    accumulating && (x.grad[] .+= Δ)
    if pending == 0
        back(x.f, accumulating ? x.grad[] : Δ)
    end
    return
end
|
||||||
|
|
||||||
|
# `@back(x, Δ)` evaluates `x` once and calls `back(x, Δ)` only when
# `istracked(x)` holds; the `Δ` expression is not evaluated otherwise.
macro back(x, Δ)
    target = esc(x)
    grad = esc(Δ)
    return quote
        x = $target
        istracked(x) && back(x, $grad)
    end
end
|
||||||
|
|
||||||
|
# Interface methods
|
||||||
|
|
||||||
# Entry point for backpropagating the gradient `Δ` from `x`.
# Pre-commit column: add Δ to `x.grad` when it is assigned, then recurse
# through `back!(x.f, Δ)`.
# Post-commit column: first `scan(x)` to count references through the recorded
# calls, then start the counted traversal with `back(x, Δ)`.
function back!(x::TrackedArray, Δ)
|
function back!(x::TrackedArray, Δ)
|
||||||
isassigned(x.grad) && (x.grad[] .+= Δ)
|
scan(x)
|
||||||
back!(x.f, Δ)
|
back(x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
# With no explicit gradient, backpropagation from a TrackedScalar is seeded
# with 1. (Unchanged context row.)
back!(x::TrackedScalar) = back!(x, 1)
|
back!(x::TrackedScalar) = back!(x, 1)
|
||||||
|
|
||||||
# Pre-commit `@back!` macro; the post-commit column is empty on these rows.
# It evaluates `x` once and calls `back!(x, Δ)` only when `istracked(x)` holds.
macro back!(x, Δ)
|
|
||||||
quote
|
|
||||||
x = $(esc(x))
|
|
||||||
istracked(x) && back!(x, $(esc(Δ)))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
@ -9,21 +9,21 @@ unarray(xs::AbstractArray{T,0} where T) = xs[]
|
|||||||
# Indexing a tracked array records a `getindex` Call and wraps the indexed
# value `xs.data[i...]`, passed through `toarray`, in a new TrackedArray.
# (Unchanged context rows.)
Base.getindex(xs::TrackedArray, i...) =
|
Base.getindex(xs::TrackedArray, i...) =
|
||||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
|
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
|
||||||
|
|
||||||
# Backward rule for `getindex`: build a zero array from `xs.data`, write
# `unarray(Δ)` at the indexed position `i...`, and send the result to `xs`.
# The commit only renames `back!` to `back` and `@back!` to `@back` here.
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||||
Δ′ = zeros(xs.data)
|
Δ′ = zeros(xs.data)
|
||||||
Δ′[i...] = unarray(Δ)
|
Δ′[i...] = unarray(Δ)
|
||||||
@back!(xs, Δ′)
|
@back(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
# Unary minus on a tracked array records a `-` Call.
# Its backward rule sends the negated gradient `-Δ` to `xs`; the commit renames
# that rule from `back!` (first line of the pair) to `back` (second line).
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||||
|
|
||||||
back!(::typeof(-), Δ, xs::TrackedArray) = back!(xs, -Δ)
|
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
|
||||||
|
|
||||||
# `transpose` and `ctranspose` on a tracked array each record the matching Call.
# Their backward rules send the transposed gradient (`Δ.'`) or the conjugate-
# transposed gradient (`Δ'`), passed through `trim(xs, ...)`, to `xs` when it is
# tracked. The commit renames `back!`/`@back!` to `back`/`@back`.
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
|
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
|
||||||
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
|
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
|
||||||
|
|
||||||
back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.'))
|
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
|
||||||
back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ'))
|
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
||||||
|
|
||||||
# `repmat` on a tracked vector or matrix records a `repmat` Call with the
# repetition counts. Two methods with identical bodies are defined, for
# `Integer...` and for `Int64...`.
# NOTE(review): the `Int64` method looks like a dispatch-ambiguity workaround
# against Base — confirm. (Unchanged context rows.)
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
|
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
||||||
@ -40,10 +40,10 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
|||||||
# `vcat` of a tracked matrix with a plain matrix, in either order, records a
# `vcat` Call on both operands. (Unchanged context rows.)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
||||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
# Backward rule for two-argument `vcat`: split Δ along its first dimension.
# `i` is a tuple of `:` for every dimension of Δ after the first. Rows
# `1:size(xs,1)` go to `xs` and the remaining rows go to `ys`, each only when
# that operand is tracked. The commit renames `back!`/`@back!` to `back`/`@back`.
function back!(::typeof(vcat), Δ, xs, ys)
|
function back(::typeof(vcat), Δ, xs, ys)
|
||||||
i = Base.tail(map(_ -> :, size(Δ)))
|
i = Base.tail(map(_ -> :, size(Δ)))
|
||||||
@back!(xs, Δ[1:size(xs,1), i...])
|
@back(xs, Δ[1:size(xs,1), i...])
|
||||||
@back!(ys, Δ[size(xs,1)+1:end, i...])
|
@back(ys, Δ[size(xs,1)+1:end, i...])
|
||||||
end
|
end
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
@ -52,7 +52,7 @@ Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
|||||||
# Full `sum` of a tracked array records a `sum` Call and stores
# `sum(xs.data)`, passed through `toarray`, as the value. `sum` of a
# TrackedScalar returns the scalar itself.
# The backward rule fills a fresh array made with `similar(xs.data)` with Δ by
# broadcast assignment and sends it to `xs`; the commit renames it from `back!`
# to `back`.
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
||||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
Base.sum(xs::TrackedScalar, dim...) = xs
|
||||||
|
|
||||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.data) .= Δ)
|
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||||
|
|
||||||
# `maximum` and `findfirst` delegate straight to the underlying `xs.data`; the
# result is returned as-is and is not wrapped in a TrackedArray.
# (Unchanged context rows.)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||||
@ -67,9 +67,9 @@ a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
|
|||||||
# Matrix-vector products where exactly one operand is tracked record a `*`
# Call on both operands. (Unchanged context rows.)
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
|
a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b))
|
||||||
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
|
a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b))
|
||||||
|
|
||||||
# Backward rule for `a * b`. When `a` is tracked it receives
# `A_mul_Bt(Δ, data(b))` (Δ times the transpose of b's data); when `b` is
# tracked it receives `At_mul_B(data(a), Δ)` (the transpose of a's data times
# Δ). The commit renames `back!`/`@back!` to `back`/`@back`.
function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
||||||
@back!(a, A_mul_Bt(Δ, data(b)))
|
@back(a, A_mul_Bt(Δ, data(b)))
|
||||||
@back!(b, At_mul_B(data(a), Δ))
|
@back(b, At_mul_B(data(a), Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
# NNlib
|
# NNlib
|
||||||
@ -78,7 +78,7 @@ import NNlib: softmax, ∇softmax
|
|||||||
|
|
||||||
# `softmax` on a tracked array records a `softmax` Call.
# Its backward rule sends `∇softmax(Δ, data(xs))` to `xs` when it is tracked;
# the commit renames `back!`/`@back!` to `back`/`@back`.
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
|
||||||
|
|
||||||
back!(::typeof(softmax), Δ, xs) = @back!(xs, ∇softmax(Δ, data(xs)))
|
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
|
||||||
|
|
||||||
# Broadcasting
|
# Broadcasting
|
||||||
|
|
||||||
@ -112,9 +112,9 @@ function getpartial(Δ, x, i)
|
|||||||
return Δ * p
|
return Δ * p
|
||||||
end
|
end
|
||||||
|
|
||||||
# Backward rule for a broadcast over N arguments. For each argument index `i`
# it computes an elementwise partial with `getpartial.(Δ, b.data, i)`, then
# sends `unbroadcast(x, Δ)` to every argument `x` that is tracked.
# The commit renames `back!`/`@back!` to `back`/`@back`.
function back!(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
|
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N
|
||||||
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
|
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N})
|
||||||
foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs)
|
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs)
|
||||||
end
|
end
|
||||||
|
|
||||||
# Register TrackedArray as its own broadcast container type by extending the
# internal `Base.Broadcast._containertype` hook. (Unchanged context row.)
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
Base.Broadcast._containertype(::Type{<:TrackedArray}) = TrackedArray
|
||||||
|
@ -22,4 +22,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
# Gradient checks for `vcat`: once on two vectors, once on two matrices with
# the same number of columns. (Unchanged context rows.)
@test gradtest(vcat, rand(5), rand(3))
|
@test gradtest(vcat, rand(5), rand(3))
|
||||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||||
|
|
||||||
|
# Gradient check where the input reaches the output along two paths: once
# through `x.^2` and once directly.
@test gradtest(rand(5)) do x
    squares = x .^ 2
    2 * squares + x
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user