From 282889970dcdefc73aeaf41eb18e5b08dfabb879 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 7 Feb 2018 17:43:25 +0000 Subject: [PATCH 1/4] seperate tracking infrastructure from array wrapper --- src/Flux.jl | 2 +- src/optimise/train.jl | 6 +- src/tracker/Tracker.jl | 97 ++++++++------------------------ src/tracker/{lib.jl => array.jl} | 63 ++++++++++++++++++++- src/tracker/back.jl | 17 ++++-- test/tracker.jl | 2 +- 6 files changed, 102 insertions(+), 85 deletions(-) rename src/tracker/{lib.jl => array.jl} (82%) diff --git a/src/Flux.jl b/src/Flux.jl index 2e124655..87a37566 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -19,7 +19,7 @@ export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax, include("tracker/Tracker.jl") using .Tracker export Tracker -import .Tracker: data, value +import .Tracker: data include("optimise/Optimise.jl") using .Optimise diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 31812fa0..cb4d1c91 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,5 +1,5 @@ using Juno -using Flux.Tracker: back!, value +using Flux.Tracker: back! runall(f) = f runall(fs::AbstractVector) = () -> foreach(call, fs) @@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(value(l)) && error("Loss is Inf") - isnan(value(l)) && error("Loss is NaN") + isinf(l.data[]) && error("Loss is Inf") + isnan(l.data[]) && error("Loss is NaN") back!(l) opt() cb() == :stop && break diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 0b8ee3cd..fa01060a 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -2,8 +2,12 @@ module Tracker export TrackedArray, TrackedVector, TrackedMatrix, param, back! -data(x) = x -istracked(x) = false +tracker(x) = nothing + +istracked(x) = tracker(x) ≠ nothing +isleaf(x) = !istracked(x) || isleaf(tracker(x)) +data(x) = istracked(x) ? data(tracker(x)) : x +grad(x) = grad(tracker(x)) struct Call{F,As<:Tuple} func::F @@ -14,85 +18,32 @@ Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) (c::Call)() = c.func(data.(c.args)...) -mutable struct TrackedArray{T,N,A} <: AbstractArray{T,N} +mutable struct Tracked{T} ref::UInt32 f::Call - data::A - grad::A - TrackedArray{T,N,A}(f::Call, data::A) where {T,N,A} = new(0, f, data) - TrackedArray{T,N,A}(f::Call, data::A, grad::A) where {T,N,A} = new(0, f, data, grad) + data::T + grad::T + Tracked{T}(f::Call, data::T) where T = new(0, f, data) + Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad) end -TrackedScalar{T,A} = TrackedArray{T,0,A} -TrackedVector{T,A} = TrackedArray{T,1,A} -TrackedMatrix{T,A} = TrackedArray{T,2,A} -TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} - -TrackedArray(c::Call, x::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(c, x) - -TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) - -TrackedArray(c::Call) = TrackedArray(c, c()) - -TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) - -isleaf(x::TrackedArray) = x.f == Call(nothing) - -param(xs) = TrackedArray(map(x -> AbstractFloat(x), xs)) -param(xs::Real) = param(fill(xs)) - -istracked(x::TrackedArray) = true -data(x::TrackedArray) = x.data -grad(x::TrackedArray) = x.grad - -# Fallthrough methods - -for f in :[Base.size, Base.ndims].args - @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) -end - -Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = - similar(data(x), dims...) - -Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) - -# TODO decide if keeping both data and value. The problem is TrackedScalar -value(x) = x -value(x::TrackedArray) = data(x) -value(x::TrackedScalar) = data(x)[] - -Base.:(==)(x::TrackedArray, y) = value(x) == y -Base.:(==)(y, x::TrackedArray) = y == value(x) -Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y) - -Base.isless(x::TrackedScalar, y) = isless(value(x), y) -Base.isless(x, y::TrackedScalar) = isless(x, value(y)) -Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y)) -Base.isapprox(x::TrackedScalar, y; kws...) = isapprox(x.data[], y; kws...) - -Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = - print(io, "TrackedArray{…,$A}") - -function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) - if repr - print(io, "param(") - Base.showarray(io, data(X), true) - print(io, ")") - else - header && print(io, "Tracked ") - Base.showarray(io, data(X), false, header = header) - end -end - -Base.setindex!(xs::TrackedArray, v, i...) = - error("Can't differentiate `setindex!`") +istracked(x::Tracked) = true +isleaf(x::Tracked) = x.f == Call(nothing) +data(x::Tracked) = x.data +grad(x::Tracked) = x.grad include("back.jl") -include("lib.jl") +include("array.jl") include("numeric.jl") +param(x::Number) = TrackedArray(fill(0)) +Base.isless(x::TrackedScalar, y) = isless(x.data[], y) +Base.isless(x, y::TrackedScalar) = isless(x, y.data[]) +Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[]) +back!(x::TrackedScalar) = back!(x, 1) + +param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs)) + using DataFlow using DataFlow: inputnode, constant diff --git a/src/tracker/lib.jl b/src/tracker/array.jl similarity index 82% rename from src/tracker/lib.jl rename to src/tracker/array.jl index ceb43aea..6bc06d57 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/array.jl @@ -1,3 +1,65 @@ +struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} + tracker::Tracked{A} + data::A + grad::A + TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data) + TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) +end + +tracker(x::TrackedArray) = x.tracker + +TrackedScalar{T,A} = TrackedArray{T,0,A} +TrackedVector{T,A} = TrackedArray{T,1,A} +TrackedMatrix{T,A} = TrackedArray{T,2,A} +TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} + +TrackedArray(c::Call, x::A) where A <: AbstractArray = + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) + +TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) + +TrackedArray(c::Call) = TrackedArray(c, c()) + +TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) + +Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = + print(io, "TrackedArray{…,$A}") + +function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) + if repr + print(io, "param(") + Base.showarray(io, data(X), true) + print(io, ")") + else + header && print(io, "Tracked ") + Base.showarray(io, data(X), false, header = header) + end +end + +Base.setindex!(xs::TrackedArray, v, i...) = + error("Can't differentiate `setindex!`") + +# Fallthrough methods + +for f in :[Base.size, Base.ndims].args + @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) +end + +Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = + similar(data(x), dims...) + +Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) + +value(x) = data(x) +value(x::TrackedScalar) = data(x)[] + +Base.:(==)(x::TrackedArray, y) = value(x) == y +Base.:(==)(y, x::TrackedArray) = y == value(x) +Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y) + +# Array Stdlib + toarray(xs::AbstractArray, ys::AbstractArray) = ys toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y @@ -60,7 +122,6 @@ back(::typeof(reshape), Δ, xs::TrackedArray, _...) = Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data))) -Base.sum(xs::TrackedScalar, dim...) = xs back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index b4cd27c6..a01e9313 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,8 +1,6 @@ -scan(x) = nothing - scan(c::Call) = foreach(scan, c.args) -function scan(x::TrackedArray) +function scan(x::Tracked) ref = x.ref += 1 if ref == 1 scan(x.f) @@ -12,11 +10,16 @@ function scan(x::TrackedArray) return end +function scan(x) + istracked(x) && scan(tracker(x)) + return +end + back_(f, y, args...) = back(f, args...) back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) back_(::Call{Void}, y, Δ) = nothing -function back(x::TrackedArray, Δ) +function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad .+= Δ @@ -27,6 +30,8 @@ function back(x::TrackedArray, Δ) return end +back(x, Δ) = back(tracker(x), Δ) + macro back(x, Δ) quote x = $(esc(x)) @@ -39,9 +44,9 @@ end # TODO: if an error occurs in `back` the refcounts will be broken # and `back` will silently fail to update. -function back!(x::TrackedArray, Δ) +function back!(x::Tracked, Δ) scan(x) back(x, Δ) end -back!(x::TrackedScalar) = back!(x, 1) +back!(x, Δ) = back!(tracker(x), Δ) diff --git a/test/tracker.jl b/test/tracker.jl index 4f865957..199f45ce 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -13,7 +13,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) @test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) -@test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5)) +@test gradtest(x -> sum(x, (2, 3)), (3,4,5)) @test gradtest(x -> softmax(x).*(1:3), 3) @test gradtest(x -> softmax(x).*(1:3), (3,5)) From 79e4e25fea597d65de249eddcb3266507403ce74 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 7 Feb 2018 20:39:36 +0000 Subject: [PATCH 2/4] seperate number type --- src/optimise/train.jl | 4 +- src/tracker/Tracker.jl | 16 +++--- src/tracker/array.jl | 118 +++++++++++++++++++---------------------- src/tracker/back.jl | 6 ++- src/tracker/scalar.jl | 63 ++++++++++++++++++++++ 5 files changed, 134 insertions(+), 73 deletions(-) create mode 100644 src/tracker/scalar.jl diff --git a/src/optimise/train.jl b/src/optimise/train.jl index cb4d1c91..d29ec123 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -27,8 +27,8 @@ function train!(loss, data, opt; cb = () -> ()) opt = runall(opt) @progress for d in data l = loss(d...) - isinf(l.data[]) && error("Loss is Inf") - isnan(l.data[]) && error("Loss is NaN") + isinf(l) && error("Loss is Inf") + isnan(l) && error("Loss is NaN") back!(l) opt() cb() == :stop && break diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index fa01060a..472441af 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -27,22 +27,24 @@ mutable struct Tracked{T} Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, data, grad) end +Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) + +track(f::Call, x) = Tracked(f, x) +track(f::Call) = track(f, f()) +track(f, xs...) = track(Call(f, xs...)) + istracked(x::Tracked) = true isleaf(x::Tracked) = x.f == Call(nothing) data(x::Tracked) = x.data grad(x::Tracked) = x.grad include("back.jl") +include("scalar.jl") include("array.jl") include("numeric.jl") -param(x::Number) = TrackedArray(fill(0)) -Base.isless(x::TrackedScalar, y) = isless(x.data[], y) -Base.isless(x, y::TrackedScalar) = isless(x, y.data[]) -Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(x.data[], y.data[]) -back!(x::TrackedScalar) = back!(x, 1) - -param(xs::AbstractArray) = TrackedArray(map(x -> AbstractFloat(x), xs)) +param(x::Number) = TrackedNumber(float(x)) +param(xs::AbstractArray) = TrackedArray(float.(xs)) using DataFlow using DataFlow: inputnode, constant diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 6bc06d57..93ec7bce 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -8,19 +8,18 @@ end tracker(x::TrackedArray) = x.tracker -TrackedScalar{T,A} = TrackedArray{T,0,A} TrackedVector{T,A} = TrackedArray{T,1,A} TrackedMatrix{T,A} = TrackedArray{T,2,A} TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} +track(c::Call, x::AbstractArray) = TrackedArray(c, x) + TrackedArray(c::Call, x::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) -TrackedArray(c::Call) = TrackedArray(c, c()) - TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = @@ -40,6 +39,8 @@ end Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") +back!(::TrackedArray) = error("Use back!(x, Δ)") + # Fallthrough methods for f in :[Base.size, Base.ndims].args @@ -51,57 +52,47 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -value(x) = data(x) -value(x::TrackedScalar) = data(x)[] - -Base.:(==)(x::TrackedArray, y) = value(x) == y -Base.:(==)(y, x::TrackedArray) = y == value(x) -Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y) +Base.:(==)(x::TrackedArray, y) = data(x) == y +Base.:(==)(y, x::TrackedArray) = y == data(x) +Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) # Array Stdlib -toarray(xs::AbstractArray, ys::AbstractArray) = ys -toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y - -unarray(xs) = xs -unarray(xs::AbstractArray{T,0} where T) = xs[] - -Base.getindex(xs::TrackedArray, i...) = - TrackedArray(Call(getindex, xs, i...), toarray(xs.data, xs.data[i...])) +Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) function back(::typeof(getindex), Δ, xs::TrackedArray, i...) Δ′ = zeros(xs.data) - Δ′[i...] = unarray(Δ) + Δ′[i...] = Δ @back(xs, Δ′) end -Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) +Base.:-(xs::TrackedArray) = track(-, xs) back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ) -Base.transpose(xs::TrackedArray) = TrackedArray(Call(transpose, xs)) -Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) +Base.transpose(xs::TrackedArray) = track(transpose, xs) +Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.')) back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) -Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...)) -Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...)) +Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) +Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...)) -Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b) +Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...) +Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b) +Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...)) -Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) +Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...) +Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b) +Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...)) -Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) +Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) +Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) function back(::typeof(vcat), Δ, xs...) i = Base.tail(map(_ -> :, size(Δ))) @@ -113,27 +104,27 @@ function back(::typeof(vcat), Δ, xs...) end Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = - TrackedArray(Call(reshape, xs, dims...)) + track(reshape, xs, dims...) back(::typeof(reshape), Δ, xs::TrackedArray, _...) = back(xs, reshape(Δ, size(xs))) # Reductions -Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) -Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.data, sum(xs.data))) +Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) +Base.sum(xs::TrackedArray) = track(sum, xs) back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) -Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) -Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +Base.mean(xs::TrackedArray) = track(mean, xs) +Base.mean(xs::TrackedArray, region) = track(mean, xs, region) -LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) -LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) -LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = TrackedArray(Call(dot, xs, ys), toarray(xs.data, dot(data(xs), data(ys)))) +LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) +LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) +LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) function back(::typeof(dot), Δ, xs, ys) @back(xs, Δ.*ys) @@ -152,23 +143,23 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -Base.diagm(x::TrackedVector) = TrackedArray(Call(diagm, x)) +Base.diagm(x::TrackedVector) = track(diagm, x) back(::typeof(diagm), Δ, x) = @back(x, diag(Δ)) for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin import Base.$f - $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) - $f(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call($f, a, b)) - $f(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b) + $f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b) + $f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b) - $f(a::TrackedMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) - $f(a::TrackedMatrix, b::AbstractVector) = TrackedArray(Call($f, a, b)) - $f(a::AbstractMatrix, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b) + $f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b) + $f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b) - $f(a::TrackedVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) - $f(a::TrackedVector, b::AbstractVector) = TrackedArray(Call($f, a, b)) - $f(a::AbstractVector, b::TrackedVector) = TrackedArray(Call($f, a, b)) + $f(a::TrackedVector, b::TrackedVector) = track($f, a, b) + $f(a::TrackedVector, b::AbstractVector) = track($f, a, b) + $f(a::AbstractVector, b::TrackedVector) = track($f, a, b) end end @@ -202,11 +193,11 @@ end using NNlib import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool -softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs)) +softmax(xs::TrackedArray) = track(softmax, xs) back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) -logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs)) +logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) @@ -214,11 +205,11 @@ back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad) conv2d(x::TrackedArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = - TrackedArray(Call(_conv2d, x, w, stride, padding)) + track(_conv2d, x, w, stride, padding) conv2d(x::AbstractArray{<:Any,4}, w::TrackedArray{<:Any,4}; stride = 1, padding = 0) = - TrackedArray(Call(_conv2d, x, w, stride, padding)) + track(_conv2d, x, w, stride, padding) conv2d(x::TrackedArray{<:Any,4}, w::AbstractArray{<:Any,4}; stride = 1, padding = 0) = - TrackedArray(Call(_conv2d, x, w, stride, padding)) + track(_conv2d, x, w, stride, padding) function back(::typeof(_conv2d), Δ, x, w, stride, pad) @back(x, NNlib.conv2d_grad_x(data(x), data(w), Δ; stride = stride, padding = pad)) @@ -228,7 +219,7 @@ end _pool(x, k, pad, mode) = pool(x, window = k, mode = mode, padding = pad) pool(x::TrackedArray{<:Any,4}; window = 2, mode = 0, padding = 0) = - TrackedArray(Call(_pool, x, window, padding, mode)) + track(_pool, x, window, padding, mode) back_(::typeof(_pool), y, Δ, x, k, pad, mode) = back(x, NNlib.pool_grad(data(x), y, Δ, window=k, mode=mode, padding=pad)) @@ -246,23 +237,24 @@ end dualify(xs, n) = xs dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) +dualify(xs::TrackedNumber, ps) = Dual(data(xs), ps) function tracked_broadcast(f, args::Vararg{Any,N}) where N dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) out = broadcast(f, dargs...) eltype(out) <: Dual || return out - # TrackedArray(Call(Broadcasted(f, broadcast(f, dargs...)), args...)) - # Works around a 0.6 type inference issue b = Broadcasted(f, out) - TrackedArray(Call(b, args...), b()) + track(Call(b, args...), b()) end trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) -unbroadcast(x, Δ) = +unbroadcast(x::AbstractArray, Δ) = size(x) == size(Δ) ? Δ : trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))) +unbroadcast(x::Number, Δ) = sum(Δ) + function getpartial(Δ, x, i) @inbounds p = getindex(partials(x), i) return Δ * p diff --git a/src/tracker/back.jl b/src/tracker/back.jl index a01e9313..37c233e1 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -19,10 +19,13 @@ back_(f, y, args...) = back(f, args...) back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) back_(::Call{Void}, y, Δ) = nothing +accum!(x::Tracked, Δ) = (x.grad += Δ) +accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ) + function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) - x.grad .+= Δ + accum!(x, Δ) ref == 0 && back_(x.f, x.data, x.grad) else ref == 0 && back_(x.f, x.data, Δ) @@ -31,6 +34,7 @@ function back(x::Tracked, Δ) end back(x, Δ) = back(tracker(x), Δ) +back(x::Void, Δ) = error("Can't backpropagate through `nothing`") macro back(x, Δ) quote diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl new file mode 100644 index 00000000..026d2aeb --- /dev/null +++ b/src/tracker/scalar.jl @@ -0,0 +1,63 @@ +struct TrackedNumber{T<:Number} <: Number + tracker::Tracked{T} +end + +TrackedNumber(x::Number) = TrackedNumber(Tracked(Call(nothing), x)) + +tracker(x::TrackedNumber) = x.tracker + +track(f::Call, x::Number) = TrackedNumber(Tracked(f, x)) + +back!(x::TrackedNumber) = back!(x, 1) + +function Base.show(io::IO, x::TrackedNumber) + show(io, data(x)) + print(io, " (tracked)") +end + +Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber{T}) where T = x + +Base.convert(::Type{TrackedNumber{T}}, x::TrackedNumber) where T = + TrackedNumber(Tracked(x.tracker.f, convert(T, x.tracker.data))) + +Base.convert(::Type{TrackedNumber{T}}, x::Number) where T = TrackedNumber(convert(T, x)) + +Base.isless(x::TrackedNumber, y::Number) = isless(data(x), y) +Base.isless(x::Number, y::TrackedNumber) = isless(x, data(y)) +Base.isless(x::TrackedNumber, y::TrackedNumber) = isless(data(x), data(y)) + +Base.:(==)(x::TrackedNumber, y::Number) = data(x) == y +Base.:(==)(x::Number, y::TrackedNumber) = x == data(y) +Base.:(==)(x::TrackedNumber, y::TrackedNumber) = data(x) == data(y) + +for f in :[isinf, isnan].args + @eval Base.$f(x::TrackedNumber) = isinf(data(x)) +end + +Base.promote_rule(::Type{TrackedNumber{S}},::Type{T}) where {S,T} = + TrackedNumber{promote_type(S,T)} + +using DiffRules, SpecialFunctions, NaNMath + +for (M, f, arity) in DiffRules.diffrules() + arity == 1 || continue + @eval begin + $M.$f(a::TrackedNumber) = track($M.$f, a) + back(::typeof($M.$f), Δ::Number, a::TrackedNumber) = + back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a))))) + end +end + +for (M, f, arity) in DiffRules.diffrules() + arity == 2 || continue + da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) + @eval begin + $M.$f(a::TrackedNumber, b::TrackedNumber) = track($M.$f, a, b) + $M.$f(a::TrackedNumber, b::Number) = track($M.$f, a, b) + $M.$f(a::Number, b::TrackedNumber) = track($M.$f, a, b) + function back(::typeof($M.$f), Δ::Number, a::Number, b::Number) + @back(a, Δ * $da) + @back(b, Δ * $db) + end + end +end From 39f7f8fdf3b011e4f782b6635ed2507b41b1882b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 7 Feb 2018 22:20:44 +0000 Subject: [PATCH 3/4] tracked tuples --- src/tracker/back.jl | 10 ++++++---- src/tracker/scalar.jl | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 37c233e1..e9bf28e0 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,3 +1,5 @@ +init_grad(x) = zero(x) + scan(c::Call) = foreach(scan, c.args) function scan(x::Tracked) @@ -5,7 +7,7 @@ function scan(x::Tracked) if ref == 1 scan(x.f) else - isdefined(x, :grad) || (x.grad = zeros(x.data)) + isdefined(x, :grad) || (x.grad = init_grad(x.data)) end return end @@ -19,13 +21,13 @@ back_(f, y, args...) = back(f, args...) back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) back_(::Call{Void}, y, Δ) = nothing -accum!(x::Tracked, Δ) = (x.grad += Δ) -accum!(x::Tracked{<:AbstractArray}, Δ) = (x.grad .+= Δ) +accum!(x, Δ) = x .+ Δ +accum!(x::AbstractArray, Δ) = (x .+= Δ) function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) - accum!(x, Δ) + x.grad = accum!(x.grad, Δ) ref == 0 && back_(x.f, x.data, x.grad) else ref == 0 && back_(x.f, x.data, Δ) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 026d2aeb..f37f8c73 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -61,3 +61,28 @@ for (M, f, arity) in DiffRules.diffrules() end end end + +# Tuples + +struct TrackedTuple{T<:Tuple} + tracker::Tracked{T} +end + +tracker(xs::TrackedTuple) = xs.tracker + +accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) +init_grad(x::Tuple) = init_grad.(x) + +track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) + +function Base.show(io::IO, xs::TrackedTuple) + show(io, data(xs)) + print(io, " (tracked)") +end + +Base.length(x::TrackedTuple) = length(data(x)) + +Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) + +back(::typeof(getindex), Δ, t, i) = + back(t, ntuple(j -> i == j ? Δ : 0, length(t))) From 0ac924e8e1cdd62410b7150c4726293c16ffa808 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 7 Feb 2018 22:52:46 +0000 Subject: [PATCH 4/4] fixups --- REQUIRE | 7 ++++++- src/tracker/Tracker.jl | 13 +++++++------ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/REQUIRE b/REQUIRE index eb7545da..b31bc6ad 100644 --- a/REQUIRE +++ b/REQUIRE @@ -3,8 +3,13 @@ DataFlow 0.2.1 Juno MacroTools 0.3.3 NNlib -ForwardDiff 0.5.0 Requires Adapt GZip Colors + +# AD +ForwardDiff 0.5.0 +DiffRules +SpecialFunctions +NaNMath diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 472441af..96ed3bcf 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -52,17 +52,18 @@ using DataFlow: inputnode, constant vcall(f, args...) = vertex(DataFlow.Call(), constant(f), args...) vcall(f::Broadcasted, args...) = vcall(broadcast, constant(f.f), args...) -function _graph(x::TrackedArray, inputs::TrackedArray...; cache = ObjectIdDict()) +_graph(x::Tracked, inputs...; cache = ObjectIdDict()) = + vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...) + +function _graph(x, inputs...; cache = ObjectIdDict()) haskey(cache, x) && return cache[x] i = findfirst(inputs, x) cache[x] = i > 0 ? inputnode(i) : - isleaf(x) ? constant(x) : - vcall(x.f.func, map(x -> _graph(x, inputs...; cache = cache), x.f.args)...) + istracked(x) ? _graph(tracker(x), inputs...; cache = cache) : + constant(x) end -_graph(x, inputs::TrackedArray...; cache = ObjectIdDict()) = constant(x) - function graph(f, args...) inputs = param.(args) _graph(f(inputs...), inputs...) @@ -70,6 +71,6 @@ end import Adapt.adapt -adapt(T, xs::TrackedArray) = TrackedArray(xs.f, adapt(T, xs.data), adapt(T, xs.grad)) +adapt(T, xs::TrackedArray) = param(adapt(T, data(xs))) end