separate tracking infrastructure from array wrapper
This commit is contained in:
parent
f9be72f545
commit
282889970d
@ -19,7 +19,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
|
|||||||
include("tracker/Tracker.jl")
|
include("tracker/Tracker.jl")
|
||||||
using .Tracker
|
using .Tracker
|
||||||
export Tracker
|
export Tracker
|
||||||
import .Tracker: data, value
|
import .Tracker: data
|
||||||
|
|
||||||
include("optimise/Optimise.jl")
|
include("optimise/Optimise.jl")
|
||||||
using .Optimise
|
using .Optimise
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Juno
|
using Juno
|
||||||
using Flux.Tracker: back!, value
|
using Flux.Tracker: back!
|
||||||
|
|
||||||
runall(f) = f
|
runall(f) = f
|
||||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||||
@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
|
|||||||
opt = runall(opt)
|
opt = runall(opt)
|
||||||
@progress for d in data
|
@progress for d in data
|
||||||
l = loss(d...)
|
l = loss(d...)
|
||||||
isinf(value(l)) && error("Loss is Inf")
|
isinf(l.data[]) && error("Loss is Inf")
|
||||||
isnan(value(l)) && error("Loss is NaN")
|
isnan(l.data[]) && error("Loss is NaN")
|
||||||
back!(l)
|
back!(l)
|
||||||
opt()
|
opt()
|
||||||
cb() == :stop && break
|
cb() == :stop && break
|
||||||
|
@ -2,8 +2,12 @@ module Tracker
|
|||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||||
|
|
||||||
data(x) = x
|
tracker(x) = nothing
|
||||||
istracked(x) = false
|
|
||||||
|
istracked(x) = tracker(x) ≠ nothing
|
||||||
|
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||||
|
data(x) = istracked(x) ? data(tracker(x)) : x
|
||||||
|
grad(x) = grad(tracker(x))
|
||||||
|
|
||||||
struct Call{F,As<:Tuple}
|
struct Call{F,As<:Tuple}
|
||||||
func::F
|
func::F
|
||||||
@ -14,85 +18,32 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||||||
|
|
||||||
(c::Call)() = c.func(data.(c.args)...)
|
(c::Call)() = c.func(data.(c.args)...)
|
||||||
|
|
||||||
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
mutable struct Tracked{T}
|
||||||
ref::UInt32
|
ref::UInt32
|
||||||
f::Call
|
f::Call
|
||||||
data::A
|
data::T
|
||||||
grad::A
|
grad::T
|
||||||
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
|
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
|
||||||
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
|
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
istracked(x::Tracked) = true
|
||||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
isleaf(x::Tracked) = x.f == Call(nothing)
|
||||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
data(x::Tracked) = x.data
|
||||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
grad(x::Tracked) = x.grad
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|
||||||
TrackedArray{eltype(A),ndims(A),A}(c, x)
|
|
||||||
|
|
||||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
|
||||||
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
|
||||||
|
|
||||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
|
||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
|
||||||
|
|
||||||
isleaf(x::TrackedArray) = x.f == Call(nothing)
|
|
||||||
|
|
||||||
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
|
|
||||||
param(xs::Real) = param(fill(xs))
|
|
||||||
|
|
||||||
istracked(x::TrackedArray) = true
|
|
||||||
data(x::TrackedArray) = x.data
|
|
||||||
grad(x::TrackedArray) = x.grad
|
|
||||||
|
|
||||||
# Fallthrough methods
|
|
||||||
|
|
||||||
for f in :[Base.size, Base.ndims].args
|
|
||||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
|
||||||
similar(data(x), dims...)
|
|
||||||
|
|
||||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
|
||||||
|
|
||||||
# TODO decide if keeping both data and value. The problem is TrackedScalar
|
|
||||||
value(x) = x
|
|
||||||
value(x::TrackedArray) = data(x)
|
|
||||||
value(x::TrackedScalar) = data(x)[]
|
|
||||||
|
|
||||||
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
|
||||||
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
|
||||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
|
|
||||||
|
|
||||||
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
|
||||||
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
|
||||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
|
||||||
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
|
|
||||||
|
|
||||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
|
||||||
print(io, "TrackedArray{…,$A}")
|
|
||||||
|
|
||||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
|
||||||
if repr
|
|
||||||
print(io, "param(")
|
|
||||||
Base.showarray(io, data(X), true)
|
|
||||||
print(io, ")")
|
|
||||||
else
|
|
||||||
header && print(io, "Tracked ")
|
|
||||||
Base.showarray(io, data(X), false, header = header)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
|
||||||
error("Can't differentiate `setindex!`")
|
|
||||||
|
|
||||||
include("back.jl")
|
include("back.jl")
|
||||||
include("lib.jl")
|
include("array.jl")
|
||||||
include("numeric.jl")
|
include("numeric.jl")
|
||||||
|
|
||||||
|
param(x::Number) = TrackedArray(fill(0))
|
||||||
|
Base.isless(x::TrackedScalar, y) = isless(x.data[], y)
|
||||||
|
Base.isless(x, y::TrackedScalar) = isless(x, y.data[])
|
||||||
|
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[])
|
||||||
|
back!(x::TrackedScalar) = back!(x, 1)
|
||||||
|
|
||||||
|
param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs))
|
||||||
|
|
||||||
using DataFlow
|
using DataFlow
|
||||||
using DataFlow: inputnode, constant
|
using DataFlow: inputnode, constant
|
||||||
|
|
||||||
|
@ -1,3 +1,65 @@
|
|||||||
|
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||||
|
tracker::Tracked{A}
|
||||||
|
data::A
|
||||||
|
grad::A
|
||||||
|
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
||||||
|
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||||
|
end
|
||||||
|
|
||||||
|
tracker(x::TrackedArray) = x.tracker
|
||||||
|
|
||||||
|
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||||
|
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||||
|
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||||
|
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||||
|
|
||||||
|
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||||
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
||||||
|
|
||||||
|
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||||
|
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||||
|
|
||||||
|
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||||
|
|
||||||
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||||
|
|
||||||
|
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||||
|
print(io, "TrackedArray{…,$A}")
|
||||||
|
|
||||||
|
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||||
|
if repr
|
||||||
|
print(io, "param(")
|
||||||
|
Base.showarray(io, data(X), true)
|
||||||
|
print(io, ")")
|
||||||
|
else
|
||||||
|
header && print(io, "Tracked ")
|
||||||
|
Base.showarray(io, data(X), false, header = header)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||||
|
error("Can't differentiate `setindex!`")
|
||||||
|
|
||||||
|
# Fallthrough methods
|
||||||
|
|
||||||
|
for f in :[Base.size, Base.ndims].args
|
||||||
|
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||||
|
similar(data(x), dims...)
|
||||||
|
|
||||||
|
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||||
|
|
||||||
|
value(x) = data(x)
|
||||||
|
value(x::TrackedScalar) = data(x)[]
|
||||||
|
|
||||||
|
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||||
|
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||||
|
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
|
||||||
|
|
||||||
|
# Array Stdlib
|
||||||
|
|
||||||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||||
|
|
||||||
@ -60,7 +122,6 @@ back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
|||||||
|
|
||||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
||||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
|
||||||
|
|
||||||
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||||
|
|
@ -1,8 +1,6 @@
|
|||||||
scan(x) = nothing
|
|
||||||
|
|
||||||
scan(c::Call) = foreach(scan, c.args)
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
|
|
||||||
function scan(x::TrackedArray)
|
function scan(x::Tracked)
|
||||||
ref = x.ref += 1
|
ref = x.ref += 1
|
||||||
if ref == 1
|
if ref == 1
|
||||||
scan(x.f)
|
scan(x.f)
|
||||||
@ -12,11 +10,16 @@ function scan(x::TrackedArray)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function scan(x)
|
||||||
|
istracked(x) && scan(tracker(x))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
back_(f, y, args...) = back(f, args...)
|
back_(f, y, args...) = back(f, args...)
|
||||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||||
back_(::Call{Void}, y, Δ) = nothing
|
back_(::Call{Void}, y, Δ) = nothing
|
||||||
|
|
||||||
function back(x::TrackedArray, Δ)
|
function back(x::Tracked, Δ)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
if isdefined(x, :grad)
|
if isdefined(x, :grad)
|
||||||
x.grad .+= Δ
|
x.grad .+= Δ
|
||||||
@ -27,6 +30,8 @@ function back(x::TrackedArray, Δ)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
back(x, Δ) = back(tracker(x), Δ)
|
||||||
|
|
||||||
macro back(x, Δ)
|
macro back(x, Δ)
|
||||||
quote
|
quote
|
||||||
x = $(esc(x))
|
x = $(esc(x))
|
||||||
@ -39,9 +44,9 @@ end
|
|||||||
# TODO: if an error occurs in `back` the refcounts will be broken
|
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||||
# and `back` will silently fail to update.
|
# and `back` will silently fail to update.
|
||||||
|
|
||||||
function back!(x::TrackedArray, Δ)
|
function back!(x::Tracked, Δ)
|
||||||
scan(x)
|
scan(x)
|
||||||
back(x, Δ)
|
back(x, Δ)
|
||||||
end
|
end
|
||||||
|
|
||||||
back!(x::TrackedScalar) = back!(x, 1)
|
back!(x, Δ) = back!(tracker(x), Δ)
|
||||||
|
@ -13,7 +13,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||||
|
|
||||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||||
|
|
||||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||||
|
Loading…
Reference in New Issue
Block a user