separate tracking infrastructure from array wrapper
This commit is contained in:
parent
f9be72f545
commit
282889970d
|
@ -19,7 +19,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
|
|||
include("tracker/Tracker.jl")
|
||||
using .Tracker
|
||||
export Tracker
|
||||
import .Tracker: data, value
|
||||
import .Tracker: data
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Juno
|
||||
using Flux.Tracker: back!, value
|
||||
using Flux.Tracker: back!
|
||||
|
||||
runall(f) = f
|
||||
runall(fs::AbstractVector) = () -> foreach(call, fs)
|
||||
|
@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
|
|||
opt = runall(opt)
|
||||
@progress for d in data
|
||||
l = loss(d...)
|
||||
isinf(value(l)) && error("Loss is Inf")
|
||||
isnan(value(l)) && error("Loss is NaN")
|
||||
isinf(l.data[]) && error("Loss is Inf")
|
||||
isnan(l.data[]) && error("Loss is NaN")
|
||||
back!(l)
|
||||
opt()
|
||||
cb() == :stop && break
|
||||
|
|
|
@ -2,8 +2,12 @@ module Tracker
|
|||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
|
||||
|
||||
data(x) = x
|
||||
istracked(x) = false
|
||||
tracker(x) = nothing
|
||||
|
||||
istracked(x) = tracker(x) ≠ nothing
|
||||
isleaf(x) = !istracked(x) || isleaf(tracker(x))
|
||||
data(x) = istracked(x) ? data(tracker(x)) : x
|
||||
grad(x) = grad(tracker(x))
|
||||
|
||||
struct Call{F,As<:Tuple}
|
||||
func::F
|
||||
|
@ -14,85 +18,32 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
|
|||
|
||||
(c::Call)() = c.func(data.(c.args)...)
|
||||
|
||||
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
|
||||
mutable struct Tracked{T}
|
||||
ref::UInt32
|
||||
f::Call
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
|
||||
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
|
||||
data::T
|
||||
grad::T
|
||||
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
|
||||
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
|
||||
end
|
||||
|
||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(c, x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
|
||||
|
||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
isleaf(x::TrackedArray) = x.f == Call(nothing)
|
||||
|
||||
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
|
||||
param(xs::Real) = param(fill(xs))
|
||||
|
||||
istracked(x::TrackedArray) = true
|
||||
data(x::TrackedArray) = x.data
|
||||
grad(x::TrackedArray) = x.grad
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
# TODO decide if keeping both data and value. The problem is TrackedScalar
|
||||
value(x) = x
|
||||
value(x::TrackedArray) = data(x)
|
||||
value(x::TrackedScalar) = data(x)[]
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
|
||||
|
||||
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
|
||||
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
|
||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
|
||||
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
header && print(io, "Tracked ")
|
||||
Base.showarray(io, data(X), false, header = header)
|
||||
end
|
||||
end
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
istracked(x::Tracked) = true
|
||||
isleaf(x::Tracked) = x.f == Call(nothing)
|
||||
data(x::Tracked) = x.data
|
||||
grad(x::Tracked) = x.grad
|
||||
|
||||
include("back.jl")
|
||||
include("lib.jl")
|
||||
include("array.jl")
|
||||
include("numeric.jl")
|
||||
|
||||
param(x::Number) = TrackedArray(fill(0))
|
||||
Base.isless(x::TrackedScalar, y) = isless(x.data[], y)
|
||||
Base.isless(x, y::TrackedScalar) = isless(x, y.data[])
|
||||
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[])
|
||||
back!(x::TrackedScalar) = back!(x, 1)
|
||||
|
||||
param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs))
|
||||
|
||||
using DataFlow
|
||||
using DataFlow: inputnode, constant
|
||||
|
||||
|
|
|
@ -1,3 +1,65 @@
|
|||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
data::A
|
||||
grad::A
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
|
||||
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
|
||||
end
|
||||
|
||||
tracker(x::TrackedArray) = x.tracker
|
||||
|
||||
TrackedScalar{T,A} = TrackedArray{T,0,A}
|
||||
TrackedVector{T,A} = TrackedArray{T,1,A}
|
||||
TrackedMatrix{T,A} = TrackedArray{T,2,A}
|
||||
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
|
||||
|
||||
TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
|
||||
|
||||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
|
||||
|
||||
TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
|
||||
print(io, "TrackedArray{…,$A}")
|
||||
|
||||
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
|
||||
if repr
|
||||
print(io, "param(")
|
||||
Base.showarray(io, data(X), true)
|
||||
print(io, ")")
|
||||
else
|
||||
header && print(io, "Tracked ")
|
||||
Base.showarray(io, data(X), false, header = header)
|
||||
end
|
||||
end
|
||||
|
||||
Base.setindex!(xs::TrackedArray, v, i...) =
|
||||
error("Can't differentiate `setindex!`")
|
||||
|
||||
# Fallthrough methods
|
||||
|
||||
for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
similar(data(x), dims...)
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
value(x) = data(x)
|
||||
value(x::TrackedScalar) = data(x)[]
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = value(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == value(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||
|
||||
|
@ -60,7 +122,6 @@ back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
|||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
|
||||
Base.sum(xs::TrackedScalar, dim...) = xs
|
||||
|
||||
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
|
||||
|
|
@ -1,8 +1,6 @@
|
|||
scan(x) = nothing
|
||||
|
||||
scan(c::Call) = foreach(scan, c.args)
|
||||
|
||||
function scan(x::TrackedArray)
|
||||
function scan(x::Tracked)
|
||||
ref = x.ref += 1
|
||||
if ref == 1
|
||||
scan(x.f)
|
||||
|
@ -12,11 +10,16 @@ function scan(x::TrackedArray)
|
|||
return
|
||||
end
|
||||
|
||||
function scan(x)
|
||||
istracked(x) && scan(tracker(x))
|
||||
return
|
||||
end
|
||||
|
||||
back_(f, y, args...) = back(f, args...)
|
||||
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
|
||||
back_(::Call{Void}, y, Δ) = nothing
|
||||
|
||||
function back(x::TrackedArray, Δ)
|
||||
function back(x::Tracked, Δ)
|
||||
ref = x.ref -= 1
|
||||
if isdefined(x, :grad)
|
||||
x.grad .+= Δ
|
||||
|
@ -27,6 +30,8 @@ function back(x::TrackedArray, Δ)
|
|||
return
|
||||
end
|
||||
|
||||
back(x, Δ) = back(tracker(x), Δ)
|
||||
|
||||
macro back(x, Δ)
|
||||
quote
|
||||
x = $(esc(x))
|
||||
|
@ -39,9 +44,9 @@ end
|
|||
# TODO: if an error occurs in `back` the refcounts will be broken
|
||||
# and `back` will silently fail to update.
|
||||
|
||||
function back!(x::TrackedArray, Δ)
|
||||
function back!(x::Tracked, Δ)
|
||||
scan(x)
|
||||
back(x, Δ)
|
||||
end
|
||||
|
||||
back!(x::TrackedScalar) = back!(x, 1)
|
||||
back!(x, Δ) = back!(tracker(x), Δ)
|
||||
|
|
|
@ -13,7 +13,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
|
||||
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
|
||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
@test gradtest(x -> softmax(x).*(1:3), (3,5))
|
||||
|
|
Loading…
Reference in New Issue