seperate number type

This commit is contained in:
Mike J Innes 2018-02-07 20:39:36 +00:00
parent 282889970d
commit 79e4e25fea
5 changed files with 134 additions and 73 deletions

View File

@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ())
opt = runall(opt)
@progress for d in data
l = loss(d...)
isinf(l.data[]) && error("Loss is Inf")
isnan(l.data[]) && error("Loss is NaN")
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
back!(l)
opt()
cb() == :stop && break

View File

@ -27,22 +27,24 @@ mutable struct Tracked{T}
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad)
end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
track(f, xs...) = track(Call(f, xs...))
istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing)
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad
include("back.jl")
include("scalar.jl")
include("array.jl")
include("numeric.jl")
param(x::Number) = TrackedArray(fill(0))
Base.isless(x::TrackedScalar, y) = isless(x.data[], y)
Base.isless(x, y::TrackedScalar) = isless(x, y.data[])
Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[])
back!(x::TrackedScalar) = back!(x, 1)
param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs))
param(x::Number) = TrackedNumber(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
using DataFlow
using DataFlow: inputnode, constant

View File

@ -8,19 +8,18 @@ end
tracker(x::TrackedArray) = x.tracker
TrackedScalar{T,A} = TrackedArray{T,0,A}
TrackedVector{T,A} = TrackedArray{T,1,A}
TrackedMatrix{T,A} = TrackedArray{T,2,A}
TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ)
TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
@ -40,6 +39,8 @@ end
Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
back!(::TrackedArray) = error("Use back!(x, Δ)")
# Fallthrough methods
for f in :[Base.size, Base.ndims].args
@ -51,57 +52,47 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
value(x) = data(x)
value(x::TrackedScalar) = data(x)[]
Base.:(==)(x::TrackedArray, y) = value(x) == y
Base.:(==)(y, x::TrackedArray) = y == value(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
Base.:(==)(x::TrackedArray, y) = data(x) == y
Base.:(==)(y, x::TrackedArray) = y == data(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
# Array Stdlib
toarray(xs::AbstractArray, ys::AbstractArray) = ys
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
unarray(xs) = xs
unarray(xs::AbstractArray{T,0} where T) = xs[]
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...]))
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
function back(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs.data)
Δ′[i...] = unarray(Δ)
Δ′[i...] = Δ
@back(xs, Δ′)
end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
Base.:-(xs::TrackedArray) = track(-, xs)
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ)
Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs))
Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs))
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.'))
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...))
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(vcat), Δ, xs...)
i = Base.tail(map(_ -> :, size(Δ)))
@ -113,27 +104,27 @@ function back(::typeof(vcat), Δ, xs...)
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
TrackedArray(Call(reshape, xs, dims...))
track(reshape, xs, dims...)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
# Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data)))
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs)
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys))))
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
function back(::typeof(dot), Δ, xs, ys)
@back(xs, Δ.*ys)
@ -152,23 +143,23 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) =
# BLAS
Base.diagm(x::TrackedVector) = TrackedArray(Call(diagm, x))
Base.diagm(x::TrackedVector) = track(diagm, x)
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ))
for f in :[*, Ac_mul_B, A_mul_Bc].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
$f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b))
$f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b))
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
end
end
@ -202,11 +193,11 @@ end
using NNlib
import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
softmax(xs::TrackedArray) = track(softmax, xs)
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs)))
logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
@ -214,11 +205,11 @@ back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) =
TrackedArray(Call(_conv2d, x, w, stride, padding))
track(_conv2d, x, w, stride, padding)
function back(::typeof(_conv2d), Δ, x, w, stride, pad)
@back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad))
@ -228,7 +219,7 @@ end
_pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad)
pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) =
TrackedArray(Call(_pool, x, window, padding, mode))
track(_pool, x, window, padding, mode)
back_(::typeof(_pool), y, Δ, x, k, pad, mode) =
back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad))
@ -246,23 +237,24 @@ end
dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs))
dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps)
function tracked_broadcast(f, args::Vararg{Any,N}) where N
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
# TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...))
# Works around a 0.6 type inference issue
b = Broadcasted(f, out)
TrackedArray(Call(b, args...), b())
track(Call(b, args...), b())
end
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)}))
unbroadcast(x, Δ) =
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
unbroadcast(x::Number, Δ) = sum(Δ)
function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i)
return Δ * p

View File

@ -19,10 +19,13 @@ back_(f, y, args...) = back(f, args...)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...)
back_(::Call{Void}, y, Δ) = nothing
accum!(x::Tracked, Δ) = (x.grad += Δ)
accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ)
function back(x::Tracked, Δ)
ref = x.ref -= 1
if isdefined(x, :grad)
x.grad .+= Δ
accum!(x, Δ)
ref == 0 && back_(x.f, x.data, x.grad)
else
ref == 0 && back_(x.f, x.data, Δ)
@ -31,6 +34,7 @@ function back(x::Tracked, Δ)
end
back(x, Δ) = back(tracker(x), Δ)
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
macro back(x, Δ)
quote

63
src/tracker/scalar.jl Normal file
View File

@ -0,0 +1,63 @@
struct TrackedNumber{T<:Number} <: Number
tracker::Tracked{T}
end
TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x))
tracker(x::TrackedNumber) = x.tracker
track(f::Call, x::Number) = TrackedNumber(Tracked(f, x))
back!(x::TrackedNumber) = back!(x, 1)
function Base.show(io::IO, x::TrackedNumber)
show(io, data(x))
print(io, " (tracked)")
end
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x
Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T =
TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data)))
Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x))
Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y)
Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y))
Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y))
Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y
Base.:(==)(x::Number, y::TrackedNumber) = x == data(y)
Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y)
for f in :[isinf, isnan].args
@eval Base.$f(x::TrackedNumber) = isinf(data(x))
end
Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} =
TrackedNumber{promote_type(S,T)}
using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
@eval begin
$M.$f(a::TrackedNumber) = track($M.$f, a)
back(::typeof($M.$f), Δ::Number, a::TrackedNumber) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end
end
for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
@eval begin
$M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b)
$M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b)
$M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b)
function back(::typeof($M.$f), Δ::Number, a::Number, b::Number)
@back(a, Δ * $da)
@back(b, Δ * $db)
end
end
end