seperate tracking infrastructure from array wrapper

This commit is contained in:
Mike J Innes 2018-02-07 17:43:25 +00:00
parent f9be72f545
commit 282889970d
6 changed files with 102 additions and 85 deletions

View File

@ -19,7 +19,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
include("tracker/Tracker.jl")
using .Tracker
export Tracker
import .Tracker: data, value
import .Tracker: data
include("optimise/Optimise.jl")
using .Optimise

View File

@ -1,5 +1,5 @@
using Juno
using Flux.Tracker: back!, value
using Flux.Tracker: back!
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)
@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(value(l)) && error("Loss is Inf")
isnan(value(l)) && error("Loss is NaN")
isinf(l.data[]) && error("Loss is Inf")
isnan(l.data[]) && error("Loss is NaN")
back!(l)
opt()
cb() == :stop && break

View File

@ -2,8 +2,12 @@ module Tracker
export TrackedArray, TrackedVector, TrackedMatrix, param, back!
data(x) = x
istracked(x) = false
tracker(x) = nothing
istracked(x) = tracker(x) nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x))
struct Call{F,As<:Tuple}
func::F
@ -14,85 +18,32 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args)
(c::Call)() = c.func(data.(c.args)...)
mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N}
mutable struct Tracked{T}
ref::UInt32
f::Call
data::A
grad::A
TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data)
TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad)
data::T
grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, data)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
end
TrackedScalar{T,A} = TrackedArray{T,0,A}
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(c, x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
isleaf(x::TrackedArray) = x.f == Call(nothing)
param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(xs::Real) = param(fill(xs))
istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data
grad(x::TrackedArray) = x.grad
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
# TODO decide if keeping both data and value. The problem is TrackedScalar
value(x) = x
value(x::TrackedArray) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
Base.isless(x::TrackedScalar, y) = isless(value(x), y)
Base.isless(x, y::TrackedScalar) = isless(x, value(y))
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...)
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing)
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad
include("back.jl")
include("lib.jl")
include("array.jl")
include("numeric.jl")
param(x::Number) = TrackedArray(fill(0))
Base.isless(x::TrackedScalar, y) = isless(x.data[], y)
Base.isless(x, y::TrackedScalar) = isless(x, y.data[])
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[])
back!(x::TrackedScalar) = back!(x, 1)
param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs))
using DataFlow
using DataFlow: inputnode, constant

View File

@ -1,3 +1,65 @@
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
data::A
grad::A
TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data)
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end
tracker(x::TrackedArray) = x.tracker
TrackedScalar{T,A} = TrackedArray{T,0,A}
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")
function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...)
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
value(x) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
# Array Stdlib
toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
@ -60,7 +122,6 @@ back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
Base.sum(xs::TrackedScalar, dim...) = xs
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)

View File

@ -1,8 +1,6 @@
scan(x) = nothing
scan(c::Call) = foreach(scan, c.args)
function scan(x::TrackedArray)
function scan(x::Tracked)
ref = x.ref += 1
if ref == 1
scan(x.f)
@ -12,11 +10,16 @@ function scan(x::TrackedArray)
return
end
function scan(x)
istracked(x) && scan(tracker(x))
return
end
back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
function back(x::TrackedArray, Δ)
function back(x::Tracked, Δ)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad .+= Δ
@ -27,6 +30,8 @@ function back(x::TrackedArray, Δ)
return
end
back(x, Δ) = back(tracker(x), Δ)
macro back(x, Δ)
quote
x = $(esc(x))
@ -39,9 +44,9 @@ end
# TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update.
function back!(x::TrackedArray, Δ)
function back!(x::Tracked, Δ)
scan(x)
back(x, Δ)
end
back!(x::TrackedScalar) = back!(x, 1)
back!(x, Δ) = back!(tracker(x), Δ)

View File

@ -13,7 +13,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))