From 41b9412439192a93934306ebf816afe9b9b652b4 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 6 Jul 2018 11:28:18 +0100 Subject: [PATCH 1/9] new grad api --- src/tracker/Tracker.jl | 29 +++++++++++++++++++++++------ src/tracker/array.jl | 4 ++-- src/tracker/back.jl | 15 ++++++++++----- src/tracker/scalar.jl | 18 ++++++++---------- 4 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 1296d179..959fc8f1 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,5 +1,7 @@ module Tracker +using MacroTools + import Base: == export TrackedArray, TrackedVector, TrackedMatrix, param, back! @@ -17,7 +19,8 @@ struct Call{F,As<:Tuple} args::As end -Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) +Call(f, args) = Call{typeof(f),typeof(args)}(f, args) +Call() = Call(nothing, ()) # When deserialising, the object_id changes a::Call == b::Call = a.func == b.func && a.args == b.args @@ -38,15 +41,29 @@ end Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) -track(f::Call, x) = Tracked(f, x) -track(f::Call) = track(f, f()) -track(f, xs...) = track(Call(f, xs...)) - istracked(x::Tracked) = true -isleaf(x::Tracked) = x.f == Call(nothing) +isleaf(x::Tracked) = x.f == Call() data(x::Tracked) = x.data grad(x::Tracked) = x.grad +track(f::Call, x) = Tracked(f, x) +track(f::Call) = track(f, f()) + +function _forward end + +function track(f, xs...) + y, back = _forward(f, data.(xs)...) + track(Call(back, xs), y) +end + +macro grad(ex) + @capture(shortdef(ex), (name_(args__) = body_) | + (name_(args__) where {T__} = body_)) || error("Need a function definition") + T == nothing && (T = []) + unshift!(args, :(::typeof($name))) + :(Tracker._forward($(args...)) where $(T...) = $body) |> esc +end + function update!(x, Δ) tracker(x).data += Δ tracker(x).grad .= 0 diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 7a54d2eb..987630c7 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -20,7 +20,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) -TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) +TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} @@ -101,7 +101,7 @@ function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) Δ′ = similar(xs.data) Δ′ .= 0 S = size(xs.data) - + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in enumerate(IndexCartesian(), Δ) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 60b12868..5bf13d56 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -21,9 +21,14 @@ function scan(x) return end -back_(f, y, args...) = back(f, args...) -back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) -back_(::Call{Void}, y, Δ) = nothing +function back_(c::Call, Δ) + Δs = c.func(Δ) + (Δs isa Tuple && length(Δs) == length(c.args)) || + error("Gradient is not a tuple of length $(length(c.args))") + foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) +end + +back_(::Call{Void}, Δ) = nothing accum!(x, Δ) = x .+ Δ accum!(x::AbstractArray, Δ) = (x .+= Δ) @@ -33,9 +38,9 @@ function back(x::Tracked, Δ) ref = x.ref -= 1 if isdefined(x, :grad) x.grad = accum!(x.grad, Δ) - ref == 0 && back_(x.f, x.data, x.grad) + ref == 0 && back_(x.f, x.grad) else - ref == 0 && back_(x.f, x.data, Δ) + ref == 0 && back_(x.f, Δ) end return end diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 773943c0..9d2d724f 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -2,7 +2,7 @@ struct TrackedReal{T<:Real} <: Real tracker::Tracked{T} end -TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x))) +TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x))) tracker(x::TrackedReal) = x.tracker @@ -47,23 +47,21 @@ using DiffRules, SpecialFunctions, NaNMath for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin + @grad $M.$f(a::Real) = + $M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),) $M.$f(a::TrackedReal) = track($M.$f, a) - back(::typeof($M.$f), Δ::Real, a::TrackedReal) = - back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a))))) end end for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) + f = :($M.$f) @eval begin - $M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b) - $M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b) - $M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b) - function back(::typeof($M.$f), Δ::Real, a::Real, b::Real) - @back(a, Δ * $da) - @back(b, Δ * $db) - end + @grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db) + $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) + $f(a::TrackedReal, b::Real) = track($f, a, b) + $f(a::Real, b::TrackedReal) = track($f, a, b) end end From 5e319c739555a7a09447dc1aa44993eff85374e0 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 13:39:10 +0100 Subject: [PATCH 2/9] fix gradient definitions --- src/cuda/cudnn.jl | 58 ++++----- src/tracker/Tracker.jl | 11 +- src/tracker/array.jl | 271 ++++++++++++++++++----------------------- src/tracker/back.jl | 9 +- src/tracker/scalar.jl | 6 +- test/tracker.jl | 9 ++ 6 files changed, 167 insertions(+), 197 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index bcadcf4f..28a9eec3 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -340,33 +340,33 @@ function accum_transpose!(dst::CuArray, src::CuArray) return dst end -function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h) - y, ho = y_ - dy, dho = Δ - h_ = hBatch(x, data(h)) - dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) - @back(x, dx) - @back(h, unbroadcast(h, dh)) - (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) - # We don't have to make this assumption, it's just slightly more complex. - @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) - istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) - istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) - istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -end +# function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h) +# y, ho = y_ +# dy, dho = Δ +# h_ = hBatch(x, data(h)) +# dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) +# @back(x, dx) +# @back(h, unbroadcast(h, dh)) +# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) +# # We don't have to make this assumption, it's just slightly more complex. +# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) +# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) +# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) +# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) +# end -function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) - y, ho, co = y_ - dy, dho, dco = Δ - h_ = hBatch(x, data(h)) - c_ = hBatch(x, data(c)) - dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) - @back(x, dx) - @back(h, unbroadcast(h, dh)) - @back(c, unbroadcast(h, dc)) - (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) - @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) - istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) - istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) - istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -end +# function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) +# y, ho, co = y_ +# dy, dho, dco = Δ +# h_ = hBatch(x, data(h)) +# c_ = hBatch(x, data(c)) +# dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) +# @back(x, dx) +# @back(h, unbroadcast(h, dh)) +# @back(c, unbroadcast(h, dc)) +# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) +# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) +# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) +# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) +# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) +# end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 959fc8f1..4a58df29 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,6 +1,7 @@ module Tracker using MacroTools +using MacroTools: @q import Base: == @@ -51,8 +52,8 @@ track(f::Call) = track(f, f()) function _forward end -function track(f, xs...) - y, back = _forward(f, data.(xs)...) +function track(f, xs...; kw...) + y, back = _forward(f, data.(xs)...; kw...) track(Call(back, xs), y) end @@ -60,8 +61,8 @@ macro grad(ex) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {T__} = body_)) || error("Need a function definition") T == nothing && (T = []) - unshift!(args, :(::typeof($name))) - :(Tracker._forward($(args...)) where $(T...) = $body) |> esc + insert!(args, 1+isexpr(args[1], :parameters) , :(::typeof($name))) + @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc end function update!(x, Δ) @@ -83,7 +84,7 @@ Hook into gradient backpropagation. `x` is unmodified, but when backpropagating the sign of the gradient applied to `x`. """ hook(f, x) = istracked(x) ? track(hook, f, x) : x -back(::typeof(hook), Δ, f, x) = @back(x, f(Δ)) +@grad hook(f, x) = x, Δ -> (nothing, f(Δ)) param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 987630c7..709f0136 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -62,45 +62,47 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) -function back(::typeof(getindex), Δ, xs::TrackedArray, i...) - Δ′ = zeros(xs.data) - Δ′[i...] = Δ - @back(xs, Δ′) +@grad function getindex(xs, i...) + data(xs)[i...], function (Δ) + Δ′ = zeros(xs) + Δ′[i...] = Δ + (Δ′, map(_->nothing, i)...) + end end Base.:-(xs::TrackedArray) = track(-, xs) -back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ) +@grad -(xs) = -xs, Δ -> (-Δ,) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) -back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.')) -back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) +@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),) +@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) -function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) - Δ′ = similar(xs.data) - S = size(xs.data) +@grad function repmat(xs, m, n = 1) + repmat(xs, m, n), function (Δ) + Δ′ = similar(xs) + S = size(xs) for (i,v) in enumerate(Δ) d1 = divrem(i-1, S[1]*m) x = d1[2] % S[1]+1 y = d1[1] % S[2]+1 Δ′[x, y] += v end - back(xs, Δ′) + return (Δ′, nothing, nothing) + end end +Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) -_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer) -Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer) - -function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) - Δ′ = similar(xs.data) - Δ′ .= 0 - S = size(xs.data) +@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) + repeat(xs, inner = inner, outer = outer), function (Δ) + Δ′ = zero(xs) + S = size(xs) # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in enumerate(IndexCartesian(), Δ) @@ -109,7 +111,8 @@ function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] Δ′[src_idx...] += val end - back(xs, Δ′) + (Δ′,) + end end @@ -138,42 +141,51 @@ for f in [:vcat, :hcat] end end -function back(::typeof(vcat), Δ, xs...) - start = 0 - for xsi in xs - i = map(_ -> :, size(xsi)) |> Base.tail - @back(xsi, Δ[start+1:start+size(xsi,1), i...]) - start += size(xsi, 1) +@grad function vcat(xs...) + vcat(xs...), function (Δ) + start = 0 + Δs = [begin + i = map(_ -> :, size(xsi)) |> Base.tail + d = Δ[start+1:start+size(xsi,1), i...] + start += size(xsi, 1) + d + end for xsi in xs] + return (Δs...,) end end -function back(::typeof(hcat), Δ, xs...) - start = 0 - for xsi in xs - if ndims(xsi) == 1 - @back(xsi, Δ[:, start+1]) - else - i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail - @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) - end - start += size(xsi, 2) +@grad function hcat(xs...) + hcat(xs...), function (Δ) + start = 0 + Δs = [begin + d = if ndims(xsi) == 1 + Δ[:, start+1] + else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail + Δ[:, start+1:start+size(xsi,2), i...] + end + start += size(xsi, 2) + d + end for xsi in xs] + return (Δs...,) end end Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) -function back(::typeof(cat), Δ, dims, Xs...) - start = ntuple(i -> 0, Val{ndims(Δ)}) - for xs in Xs - dim_xs = 1:ndims(xs) - till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) - - xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) - - @back(xs, reshape(Δ[xs_in_Δ...],size(xs))) - - start = start .+ till_xs +@grad function cat(dims, Xs...) + cat(dims, Xs...), function (Δ) + start = ntuple(i -> 0, Val{ndims(Δ)}) + Δs = [begin + dim_xs = 1:ndims(xs) + till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) + xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) + d = reshape(Δ[xs_in_Δ...],size(xs)) + start = start .+ till_xs + d + end for xs in Xs] + return (nothing, Δs...,) end end @@ -181,11 +193,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) -back(::typeof(reshape), Δ, xs::TrackedArray, _...) = - back(xs, reshape(Δ, size(xs))) +@grad reshape(xs, dims) = reshape(xs, dims), Δ -> (reshape(Δ, size(xs)),nothing) Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) -back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims))) +@grad permutedims(xs, dims) = permutedims(xs, dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) @@ -207,14 +218,16 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) -back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) +@grad sum(xs, dim...) = sum(xs, dim...), + Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) -back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) -back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) +@grad prod(xs) = prod(xs), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,) +@grad prod(xs, dim) = prod(xs, dim), + Δ -> (similar(xs) .= (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs))) .* Δ,nothing) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) @@ -230,10 +243,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) -function back(::typeof(dot), Δ, xs, ys) - @back(xs, Δ.*data(ys)) - @back(ys, Δ.*data(xs)) -end +@grad dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs) # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = @@ -244,39 +254,30 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = Base.vecnorm(x::TrackedArray, p::Real = 2) = sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 -back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data)) -back(::typeof(mean), Δ, xs::TrackedArray, region) = - back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) +@grad mean(xs) = mean(xs), Δ -> (similar(xs) .= Δ ./ length(xs),) +@grad mean(xs, region) = mean(xs, region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing) -function back(::typeof(maximum), Δ, xs::TrackedArray) - Δ′ = zeros(xs.data) - _, i = findmax(xs.data) +@grad function maximum(xs, r...) + maximum(xs, r...), function (Δ) + Δ′ = zeros(xs) + _, i = findmax(xs, r...) Δ′[i] = Δ - @back(xs, Δ′) + return (Δ′,map(_->nothing,r)...) + end end -function back(::typeof(maximum), Δ, xs::TrackedArray, region) - Δ′ = zeros(xs.data) - _, is = findmax(xs.data, region) - Δ′[is] = Δ - @back(xs, Δ′) -end -function back(::typeof(minimum), Δ, xs::TrackedArray) - Δ′ = zeros(xs.data) - _, i = findmin(xs.data) +@grad function minimum(xs, r...) + minimum(xs, r...), function (Δ) + Δ′ = zeros(xs) + _, i = findmin(xs, r...) Δ′[i] = Δ - @back(xs, Δ′) -end -function back(::typeof(minimum), Δ, xs::TrackedArray, region) - Δ′ = zeros(xs.data) - _, is = findmin(xs.data, region) - Δ′[is] = Δ - @back(xs, Δ′) + return (Δ′,map(_->nothing,r)...) + end end # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) -back(::typeof(diagm), Δ, x) = @back(x, diag(Δ)) +@grad diagm(x) = diagm(x), Δ -> (diag(Δ),) for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin @@ -295,30 +296,11 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args end end -function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) - @back(a, A_mul_Bt(Δ, data(b))) - @back(b, At_mul_B(data(a), Δ)) -end +@grad a::AbstractMatrix * b::AbstractVecOrMat = + a*b, Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) -function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) - @back(a, A_mul_Bt(Δ, data(b))') - @back(b, data(a)*Δ) -end - -function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) - @back(a, Δ * data(b)) - @back(b, At_mul_B(data(a), Δ)') -end - -# Fast path for matrix-vector -function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) - if isleaf(W) - W.grad .+= Δ .* data(x).' - else - back(W, A_mul_Bt(Δ, data(x))) - end - @back(x, At_mul_B(data(W), Δ)) -end +@grad Ac_mul_B(a, b) = Ac_mul_B(a, b), Δ -> (A_mul_Bt(Δ, b)', a*Δ) +@grad A_mul_Bc(a, b) = A_mul_Bc(a, b), Δ -> (Δ * b, At_mul_B(a, Δ)') # NNlib @@ -327,65 +309,42 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea softmax(xs::TrackedArray) = track(softmax, xs) -back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) +@grad softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),) logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) -back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) +@grad logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) -# TODO: can store kwargs efficiently in namedtuples -_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation) +conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) +conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) +conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) -conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = - track(_conv, x, w, stride, pad, dilation) -conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = - track(_conv, x, w, stride, pad, dilation) -conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = - track(_conv, x, w, stride, pad, dilation) +@grad conv(x, w; kw...) = + conv(x, w; kw...), + Δ -> (NNlib.∇conv_data(Δ, x, w; kw...), + NNlib.∇conv_filter(Δ, x, w; kw...)) -function back(::typeof(_conv), Δ, x, w, stride, pad, dilation) - @back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) - @back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) +maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) + +@grad function maxpool(x, k; kw...) + y = maxpool(x, k; kw...) + y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing) end -_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) +meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) -maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = - track(_maxpool, x, k, pad, stride) - -back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) = - back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride)) - -_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride) - -meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = - track(_meanpool, x, k, pad, stride) - -back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) = - back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride)) +@grad function meanpool(x, k; kw...) + y = meanpool(x, k; kw...) + y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing) +end # Broadcasting -using ForwardDiff: Dual, partials - -struct Broadcasted{F,T} - f::F - data::T -end - -(b::Broadcasted)(xs...) = map(x -> x.value, b.data) +using ForwardDiff: Dual, partials, value dualify(xs, n) = xs -dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) -dualify(xs::TrackedReal, ps) = Dual(data(xs), ps) - -function tracked_broadcast(f, args::Vararg{Any,N}) where N - dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) - out = broadcast(f, dargs...) - eltype(out) <: Dual || return out - b = Broadcasted(f, out) - track(Call(b, args...), b()) -end +dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) +dualify(xs::Real, ps) = Dual(xs, ps) trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) @@ -400,9 +359,17 @@ function getpartial(Δ, x, i) return Δ * p end -function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N - Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N}) - foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs) +function ∇broadcast(f, args::Vararg{Any,N}) where N + dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) + out = broadcast(f, dargs...) + eltype(out) <: Dual || return out + y = value.(out) + back = function (Δ) + Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N}) + map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) + end + # So we can return non-tracked arrays + track(Call(back, args), y) end Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray @@ -415,4 +382,4 @@ Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = () Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A) -Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...) +Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = ∇broadcast(f, A, Bs...) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 5bf13d56..3d769778 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -23,7 +23,7 @@ end function back_(c::Call, Δ) Δs = c.func(Δ) - (Δs isa Tuple && length(Δs) == length(c.args)) || + (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) end @@ -48,13 +48,6 @@ end back(x, Δ) = back(tracker(x), Δ) back(x::Void, Δ) = error("Can't backpropagate through `nothing`") -macro back(x, Δ) - quote - x = $(esc(x)) - istracked(x) && back(x, $(esc(Δ))) - end -end - # Interface methods # TODO: if an error occurs in `back` the refcounts will be broken diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 9d2d724f..93f9f7dc 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -100,13 +100,13 @@ back(::typeof(getindex), Δ, t, i) = function collect(xs) xs = Base.collect(xs) - track(Call(collect, xs), data.(xs)) + track(Call(collect, (xs,)), data.(xs)) end function scan(c::Call{typeof(collect)}) foreach(scan, c.args[1]) end -function back(::typeof(collect), Δ, xs) - foreach((x, Δ) -> @back(x, Δ), xs, Δ) +function back_(c::Call{typeof(collect)}, Δ) + foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ) end diff --git a/test/tracker.jl b/test/tracker.jl index 66c08f62..f1e704ad 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -111,6 +111,7 @@ end @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) +# TODO unreliable @test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5)) @@ -232,4 +233,12 @@ Tracker.back!(b) @test grad.((x,y)) == (3, 2) end +# Gradient Hooks +@testset "Hooks" begin + x = param(2) + y = Tracker.hook(-, x) + back!(y) + @test grad(x) == -1 +end + end #testset From 7778d1788434b4a2a2fb0286a19cf5ee1e73d7ac Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 16:57:44 +0100 Subject: [PATCH 3/9] functional API --- REQUIRE | 1 - src/tracker/Tracker.jl | 3 +- src/tracker/back.jl | 87 ++++++++++++++++++++++++++++++++++++++++++ src/tracker/idset.jl | 25 ++++++++++++ src/tracker/numeric.jl | 6 --- src/treelike.jl | 5 +-- 6 files changed, 116 insertions(+), 11 deletions(-) create mode 100644 src/tracker/idset.jl diff --git a/REQUIRE b/REQUIRE index 8bb92ddb..95fda02c 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,5 +1,4 @@ julia 0.6.0 -DataFlow 0.2.1 Juno MacroTools 0.3.3 NNlib diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 4a58df29..d5f7dcfb 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -1,7 +1,7 @@ module Tracker using MacroTools -using MacroTools: @q +using MacroTools: @q, @forward import Base: == @@ -71,6 +71,7 @@ function update!(x, Δ) return x end +include("idset.jl") include("back.jl") include("scalar.jl") include("array.jl") diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 3d769778..62cae1d0 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -52,6 +52,8 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`") # TODO: if an error occurs in `back` the refcounts will be broken # and `back` will silently fail to update. +# Refcounts are also probably not safe in some situations (e.g. back called +# from within a backpropagator) function back!(x::Tracked, Δ) scan(x) @@ -59,3 +61,88 @@ function back!(x::Tracked, Δ) end back!(x, Δ) = back!(tracker(x), Δ) + +# Out-of-place gradients + +struct Params + params::IdSet + Params(xs) = new(IdSet(xs)) +end + +@forward Params.params Base.start, Base.next, Base.done + +struct Grads + grads::ObjectIdDict +end + +Grads() = Grads(ObjectIdDict()) + +Base.getindex(g::Grads, x::Tracked) = g.grads[x] +function Base.getindex(g::Grads, x) + istracked(x) || error("Object not tracked: $x") + g[tracker(x)] +end + +@forward Grads.grads Base.setindex!, Base.haskey + +accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ + +function back_(g::Grads, c::Call, Δ) + Δs = c.func(Δ) + (Δs isa Tuple && length(Δs) >= length(c.args)) || + error("Gradient is not a tuple of length $(length(c.args))") + foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs) +end + +back_(g::Grads, ::Call{Void}, Δ) = nothing + +function back(g::Grads, x::Tracked, Δ) + x.isleaf && (accum!(g, x, Δ); return) + ref = x.ref -= 1 + if ref > 0 || haskey(g, x) + accum!(g, x, Δ) + ref == 0 && back_(g, x.f, g[x]) + else + ref == 0 && back_(g, x.f, Δ) + end + return +end + +back(g::Grads, x, Δ) = back(g, tracker(x), Δ) +back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`") + +function forward(f, ps::Params) + y = f() + y, function (Δ) + g = Grads() + if istracked(y) + scan(y) + back(g, y, Δ) + end + for p in ps + haskey(g, tracker(p)) || + (g[tracker(p)] = init_grad(data(p))) + end + return g + end +end + +function forward(f, args...) + args = param.(args) + y, back = forward(() -> f(args...), Params(args)) + y, Δ -> getindex.(back(Δ), args) +end + +function losscheck(x) + x isa Real || error("Function output is not scalar") + isinf(x) && error("Loss is infinite") + isnan(x) && error("Loss is NaN") +end + +function gradient(f, args...) + y, back = forward(f, args...) + losscheck(y) + return back(1) +end + +derivative(f, x) = gradient(f, x)[1] diff --git a/src/tracker/idset.jl b/src/tracker/idset.jl new file mode 100644 index 00000000..68d1eea1 --- /dev/null +++ b/src/tracker/idset.jl @@ -0,0 +1,25 @@ +struct IdSet{T} <: AbstractSet{T} + dict::ObjectIdDict + IdSet{T}() where T = new(ObjectIdDict()) +end + +Base.eltype{T}(::IdSet{T}) = T + +IdSet() = IdSet{Any}() + +Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s) +Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s) +Base.in(x, s::IdSet) = haskey(s.dict, x) + +(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...) + +IdSet(xs) = IdSet{eltype(xs)}(xs) + +Base.collect(s::IdSet) = Base.collect(keys(s.dict)) +Base.similar(s::IdSet, T::Type) = IdSet{T}() + +@forward IdSet.dict Base.length + +Base.start(s::IdSet) = start(keys(s.dict)) +Base.next(s::IdSet, st) = next(keys(s.dict), st) +Base.done(s::IdSet, st) = done(keys(s.dict), st) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 755e1f7d..e0028b7c 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -1,9 +1,3 @@ -function gradient(f, xs...) - xs = param.(xs) - back!(f(xs...)) - grad.(xs) -end - function ngradient(f, xs::AbstractArray...) grads = zeros.(xs) for (x, Δ) in zip(xs, grads), i in 1:length(x) diff --git a/src/treelike.jl b/src/treelike.jl index fbe9fcad..13e562e6 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -1,4 +1,5 @@ import Adapt: adapt +import .Tracker: IdSet children(x) = () mapchildren(f, x) = x @@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict()) cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) end -using DataFlow: OSet - -function prefor(f, x; seen = OSet()) +function prefor(f, x; seen = IdSet()) x ∈ seen && return f(x) foreach(x -> prefor(f, x, seen = seen), children(x)) From 1430053b69c28c2a55b4138b93b7db5772cff41a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 17:52:34 +0100 Subject: [PATCH 4/9] checkpoints --- src/tracker/Tracker.jl | 16 ++++++++++++++++ test/tracker.jl | 14 +++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index d5f7dcfb..fd94bdb1 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -87,6 +87,22 @@ the sign of the gradient applied to `x`. hook(f, x) = istracked(x) ? track(hook, f, x) : x @grad hook(f, x) = x, Δ -> (nothing, f(Δ)) +""" + checkpoint(f, args...) + +Behaves like `f(args...)`, but avoids storing the intermediate values needed for +calculating gradients. Instead, `f(args...)` will be called again during the +backward pass. This can be used to save memory in larger models. +""" +checkpoint(f, args...) = track(checkpoint, f, args...) + +@grad function checkpoint(f, args...) + data(f(args...)), function (Δ) + y, back = forward(f, args...) + (nothing, back(Δ)...) + end +end + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) diff --git a/test/tracker.jl b/test/tracker.jl index f1e704ad..40229e18 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -1,5 +1,5 @@ using Flux.Tracker, Base.Test, NNlib -using Flux.Tracker: TrackedReal, gradcheck, grad +using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint using NNlib: conv gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) @@ -241,4 +241,16 @@ end @test grad(x) == -1 end +@testset "Checkpointing" begin + count = 0 + function mul(a, b) + count += 1 + a * b + end + @test derivative(x -> mul(5, x), 3) == 5 + @test count == 1 + @test derivative(x -> checkpoint(mul, 5, x), 3) == 5 + @test count == 3 +end + end #testset From e763c342ee5e78485371ceaea07d41d28cbf6664 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 19:44:14 +0100 Subject: [PATCH 5/9] shave some memory --- src/tracker/Tracker.jl | 18 ++++++------------ src/tracker/array.jl | 7 ++++--- src/tracker/back.jl | 30 +++++++++++++++--------------- src/tracker/scalar.jl | 13 ++++++++----- 4 files changed, 33 insertions(+), 35 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index fd94bdb1..5afe9ced 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -11,9 +11,9 @@ tracker(x) = nothing istracked(x) = tracker(x) ≠ nothing isleaf(x) = !istracked(x) || isleaf(tracker(x)) -data(x) = istracked(x) ? data(tracker(x)) : x grad(x) = grad(tracker(x)) grad(::Void) = nothing +data(x) = x struct Call{F,As<:Tuple} func::F @@ -32,29 +32,23 @@ mutable struct Tracked{T} ref::UInt32 f::Call isleaf::Bool - data::T grad::T - Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) - Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) - Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) + Tracked{T}(f::Call) where T = new(0, f, false) + Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) + Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad) end -Tracked(f::Call, x) = Tracked{typeof(x)}(f, x) -Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ) - istracked(x::Tracked) = true isleaf(x::Tracked) = x.f == Call() -data(x::Tracked) = x.data grad(x::Tracked) = x.grad -track(f::Call, x) = Tracked(f, x) -track(f::Call) = track(f, f()) +track(f::Call, x) = Tracked{typeof(x)}(f) function _forward end function track(f, xs...; kw...) y, back = _forward(f, data.(xs)...; kw...) - track(Call(back, xs), y) + track(Call(back, tracker.(xs)), y) end macro grad(ex) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 709f0136..dbf789ac 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) end +data(x::TrackedArray) = x.data tracker(x::TrackedArray) = x.tracker TrackedVector{T,A} = TrackedArray{T,1,A} @@ -15,10 +16,10 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} track(c::Call, x::AbstractArray) = TrackedArray(c, x) TrackedArray(c::Call, x::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x) TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) + TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x)) @@ -369,7 +370,7 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) end # So we can return non-tracked arrays - track(Call(back, args), y) + track(Call(back, tracker.(args)), y) end Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 62cae1d0..c6d1646a 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -10,8 +10,6 @@ function scan(x::Tracked) if ref == 1 scan(x.f) isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) - else - isdefined(x, :grad) || (x.grad = init_grad(x.data)) end return end @@ -25,7 +23,7 @@ function back_(c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args, Δs) + foreach(back, c.args, Δs) end back_(::Call{Void}, Δ) = nothing @@ -36,8 +34,12 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ) function back(x::Tracked, Δ) x.isleaf && (x.grad = accum!(x.grad, Δ); return) ref = x.ref -= 1 - if isdefined(x, :grad) - x.grad = accum!(x.grad, Δ) + if ref > 0 || isdefined(x, :grad) + if isdefined(x, :grad) + x.grad = accum!(x.grad, Δ) + else + x.grad = Δ + end ref == 0 && back_(x.f, x.grad) else ref == 0 && back_(x.f, Δ) @@ -45,8 +47,7 @@ function back(x::Tracked, Δ) return end -back(x, Δ) = back(tracker(x), Δ) -back(x::Void, Δ) = error("Can't backpropagate through `nothing`") +back(::Void, _) = return # Interface methods @@ -55,13 +56,13 @@ back(x::Void, Δ) = error("Can't backpropagate through `nothing`") # Refcounts are also probably not safe in some situations (e.g. back called # from within a backpropagator) -function back!(x::Tracked, Δ) +function back!(x, Δ) + istracked(x) || return scan(x) - back(x, Δ) + back(tracker(x), Δ) + return end -back!(x, Δ) = back!(tracker(x), Δ) - # Out-of-place gradients struct Params @@ -91,7 +92,7 @@ function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> istracked(x) && back(g, x, Δ), c.args, Δs) + foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) end back_(g::Grads, ::Call{Void}, Δ) = nothing @@ -108,8 +109,7 @@ function back(g::Grads, x::Tracked, Δ) return end -back(g::Grads, x, Δ) = back(g, tracker(x), Δ) -back(g::Grads, x::Void, Δ) = error("Can't backpropagate through `nothing`") +back(::Grads, ::Void, _) = return function forward(f, ps::Params) y = f() @@ -117,7 +117,7 @@ function forward(f, ps::Params) g = Grads() if istracked(y) scan(y) - back(g, y, Δ) + back(g, tracker(y), Δ) end for p in ps haskey(g, tracker(p)) || diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 93f9f7dc..6232807f 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -1,12 +1,14 @@ struct TrackedReal{T<:Real} <: Real + data::T tracker::Tracked{T} end -TrackedReal(x::Real) = TrackedReal(Tracked(Call(), x, zero(x))) +TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x))) +data(x::TrackedReal) = x.data tracker(x::TrackedReal) = x.tracker -track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) +track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x))) function back!(x::TrackedReal) isinf(x) && error("Loss is Inf") @@ -73,6 +75,7 @@ import Base:^ # Tuples struct TrackedTuple{T<:Tuple} + data::T tracker::Tracked{T} end @@ -82,7 +85,7 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f)) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -100,7 +103,7 @@ back(::typeof(getindex), Δ, t, i) = function collect(xs) xs = Base.collect(xs) - track(Call(collect, (xs,)), data.(xs)) + track(Call(collect, (tracker.(xs),)), data.(xs)) end function scan(c::Call{typeof(collect)}) @@ -108,5 +111,5 @@ function scan(c::Call{typeof(collect)}) end function back_(c::Call{typeof(collect)}, Δ) - foreach((x, Δ) -> istracked(x) && back(x, Δ), c.args[1], Δ) + foreach(back, c.args[1], Δ) end From 80af9a3830afbda23b5c8c3b1cc18f129220b23d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 9 Jul 2018 23:40:07 +0100 Subject: [PATCH 6/9] broadcast efficiency --- src/tracker/array.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index dbf789ac..ee8c9b39 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -78,8 +78,8 @@ Base.:-(xs::TrackedArray) = track(-, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) -@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),) -@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),) +@grad transpose(xs) = xs.', Δ -> (reshape(Δ.', size(xs)),) +@grad ctranspose(xs) = xs', Δ -> (reshape(Δ', size(xs)),) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) @@ -347,13 +347,11 @@ dualify(xs, n) = xs dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs) dualify(xs::Real, ps) = Dual(xs, ps) -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) +unbroadcast(x::Tuple, Δ) = + x == size(Δ) ? Δ : + reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))) - -unbroadcast(x::Number, Δ) = sum(Δ) +unbroadcast(x::Tuple{}, Δ) = sum(Δ) function getpartial(Δ, x, i) @inbounds p = getindex(partials(x), i) @@ -361,13 +359,14 @@ function getpartial(Δ, x, i) end function ∇broadcast(f, args::Vararg{Any,N}) where N + sizes = size.(args) dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) out = broadcast(f, dargs...) eltype(out) <: Dual || return out y = value.(out) back = function (Δ) Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N}) - map((x, Δ) -> unbroadcast(x, Δ), args, Δargs) + map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y) From 70b5efeb4e83a1c9c317271e16181b8128b8ecfd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 10 Jul 2018 09:03:09 +0100 Subject: [PATCH 7/9] basic nested AD --- src/tracker/Tracker.jl | 10 +++- src/tracker/array.jl | 113 ++++++++++++++++++++++------------------- src/tracker/back.jl | 10 ++-- src/tracker/numeric.jl | 2 +- src/tracker/scalar.jl | 8 +-- 5 files changed, 80 insertions(+), 63 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 5afe9ced..2a94edb7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -47,7 +47,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f) function _forward end function track(f, xs...; kw...) - y, back = _forward(f, data.(xs)...; kw...) + y, back = _forward(f, xs...; kw...) track(Call(back, tracker.(xs)), y) end @@ -97,9 +97,17 @@ checkpoint(f, args...) = track(checkpoint, f, args...) end end +nobacksies(f, x) = track(nobacksies, f, x) +nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) +@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f") + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) +@grad identity(x) = data(x), Δ -> (Δ,) +param(x::TrackedReal) = track(identity, x) +param(x::TrackedArray) = track(identity, x) + import NNlib.cudata import Adapt.adapt diff --git a/src/tracker/array.jl b/src/tracker/array.jl index ee8c9b39..e034a868 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -50,6 +50,9 @@ for f in :[Base.size, Base.ndims].args @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) end +Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) = + size(data(x), i, j, is...) + Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = similar(data(x), dims...) @@ -65,54 +68,54 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) @grad function getindex(xs, i...) data(xs)[i...], function (Δ) - Δ′ = zeros(xs) - Δ′[i...] = Δ - (Δ′, map(_->nothing, i)...) + Δ′ = zero(xs) + Δ′[i...] = data(Δ) + (nobacksies(:getindex, Δ′), map(_->nothing, i)...) end end Base.:-(xs::TrackedArray) = track(-, xs) -@grad -(xs) = -xs, Δ -> (-Δ,) +@grad -(xs) = -data(xs), Δ -> (-Δ,) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) -@grad transpose(xs) = xs.', Δ -> (reshape(Δ.', size(xs)),) -@grad ctranspose(xs) = xs', Δ -> (reshape(Δ', size(xs)),) +@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),) +@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) @grad function repmat(xs, m, n = 1) - repmat(xs, m, n), function (Δ) + repmat(data(xs), m, n), function (Δ) Δ′ = similar(xs) S = size(xs) - for (i,v) in enumerate(Δ) + for (i,v) in enumerate(data(Δ)) d1 = divrem(i-1, S[1]*m) x = d1[2] % S[1]+1 y = d1[1] % S[2]+1 Δ′[x, y] += v end - return (Δ′, nothing, nothing) + return (nobacksies(:repmat, Δ′), nothing, nothing) end end Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) - repeat(xs, inner = inner, outer = outer), function (Δ) + repeat(data(xs), inner = inner, outer = outer), function (Δ) Δ′ = zero(xs) S = size(xs) # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ - for (dest_idx, val) in enumerate(IndexCartesian(), Δ) + for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ)) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then # wrap around based on original size S. src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] Δ′[src_idx...] += val end - (Δ′,) + (nobacksies(:repeat, Δ′),) end end @@ -143,7 +146,7 @@ for f in [:vcat, :hcat] end @grad function vcat(xs...) - vcat(xs...), function (Δ) + vcat(data.(xs)...), function (Δ) start = 0 Δs = [begin i = map(_ -> :, size(xsi)) |> Base.tail @@ -156,7 +159,7 @@ end end @grad function hcat(xs...) - hcat(xs...), function (Δ) + hcat(data.(xs)...), function (Δ) start = 0 Δs = [begin d = if ndims(xsi) == 1 @@ -176,7 +179,7 @@ Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) @grad function cat(dims, Xs...) - cat(dims, Xs...), function (Δ) + cat(dims, data.(Xs)...), function (Δ) start = ntuple(i -> 0, Val{ndims(Δ)}) Δs = [begin dim_xs = 1:ndims(xs) @@ -194,10 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) -@grad reshape(xs, dims) = reshape(xs, dims), Δ -> (reshape(Δ, size(xs)),nothing) +@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing) Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) -@grad permutedims(xs, dims) = permutedims(xs, dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) +@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) @@ -219,16 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) -@grad sum(xs, dim...) = sum(xs, dim...), - Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...) +@grad sum(xs, dim...) = sum(data(xs), dim...), + Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) -@grad prod(xs) = prod(xs), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,) -@grad prod(xs, dim) = prod(xs, dim), - Δ -> (similar(xs) .= (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs))) .* Δ,nothing) +@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,) +@grad prod(xs, dim) = prod(data(xs), dim), + Δ -> (nobacksies(:sum, + reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ), + nothing) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) @@ -244,7 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) -@grad dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs) +@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs) # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = @@ -255,32 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = Base.vecnorm(x::TrackedArray, p::Real = 2) = sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 -@grad mean(xs) = mean(xs), Δ -> (similar(xs) .= Δ ./ length(xs),) -@grad mean(xs, region) = mean(xs, region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing) +@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),) +@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing) @grad function maximum(xs, r...) - maximum(xs, r...), function (Δ) - Δ′ = zeros(xs) - _, i = findmax(xs, r...) - Δ′[i] = Δ - return (Δ′,map(_->nothing,r)...) + maximum(data(xs), r...), function (Δ) + Δ′ = zero(xs) + _, i = findmax(data(xs), r...) + Δ′[i] = data(Δ) + return (nobacksies(:maximum, Δ′),map(_->nothing,r)...) end end @grad function minimum(xs, r...) - minimum(xs, r...), function (Δ) - Δ′ = zeros(xs) - _, i = findmin(xs, r...) - Δ′[i] = Δ - return (Δ′,map(_->nothing,r)...) + minimum(data(xs), r...), function (Δ) + Δ′ = zero(xs) + _, i = findmin(data(xs), r...) + Δ′[i] = data(Δ) + return (nobacksies(:minimum, Δ′),map(_->nothing,r)...) end end # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) -@grad diagm(x) = diagm(x), Δ -> (diag(Δ),) +@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) -for f in :[*, Ac_mul_B, A_mul_Bc].args +for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args @eval begin import Base.$f $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b) @@ -298,10 +303,13 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args end @grad a::AbstractMatrix * b::AbstractVecOrMat = - a*b, Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) + data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) -@grad Ac_mul_B(a, b) = Ac_mul_B(a, b), Δ -> (A_mul_Bt(Δ, b)', a*Δ) -@grad A_mul_Bc(a, b) = A_mul_Bc(a, b), Δ -> (Δ * b, At_mul_B(a, Δ)') +@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) +@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') + +@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) +@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') # NNlib @@ -310,33 +318,34 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea softmax(xs::TrackedArray) = track(softmax, xs) -@grad softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),) +@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),) logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) -@grad logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) +@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = - conv(x, w; kw...), - Δ -> (NNlib.∇conv_data(Δ, x, w; kw...), - NNlib.∇conv_filter(Δ, x, w; kw...)) + conv(data(x), data(w); kw...), + Δ -> nobacksies(:conv, + (NNlib.∇conv_data(data.((Δ, x, w))...; kw...), + NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) @grad function maxpool(x, k; kw...) - y = maxpool(x, k; kw...) - y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing) + y = maxpool(data(x), k; kw...) + y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing) end meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) @grad function meanpool(x, k; kw...) - y = meanpool(x, k; kw...) - y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing) + y = meanpool(data(x), k; kw...) + y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing) end # Broadcasting @@ -364,9 +373,11 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N out = broadcast(f, dargs...) eltype(out) <: Dual || return out y = value.(out) - back = function (Δ) + back = function (Δ_) + Δ = data(Δ_) Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N}) - map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) + dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) + nobacksies(:broadcast, dxs) end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index c6d1646a..8e492861 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -23,7 +23,7 @@ function back_(c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach(back, c.args, Δs) + foreach(back, c.args, data.(Δs)) end back_(::Call{Void}, Δ) = nothing @@ -78,6 +78,8 @@ end Grads() = Grads(ObjectIdDict()) +Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps)) + Base.getindex(g::Grads, x::Tracked) = g.grads[x] function Base.getindex(g::Grads, x) istracked(x) || error("Object not tracked: $x") @@ -114,15 +116,11 @@ back(::Grads, ::Void, _) = return function forward(f, ps::Params) y = f() y, function (Δ) - g = Grads() + g = Grads(ps) if istracked(y) scan(y) back(g, tracker(y), Δ) end - for p in ps - haskey(g, tracker(p)) || - (g[tracker(p)] = init_grad(data(p))) - end return g end end diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index e0028b7c..1ad872e4 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -15,4 +15,4 @@ end gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), - gradient(f, xs...), rtol = 1e-5, atol = 1e-5)) + data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5)) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 6232807f..7e574fd9 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -50,17 +50,17 @@ for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin @grad $M.$f(a::Real) = - $M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),) + $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) $M.$f(a::TrackedReal) = track($M.$f, a) end end for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue - da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) + da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b) @@ -111,5 +111,5 @@ function scan(c::Call{typeof(collect)}) end function back_(c::Call{typeof(collect)}, Δ) - foreach(back, c.args[1], Δ) + foreach(back, c.args[1], data(Δ)) end From 10a169bb779cbfa6315af1b0e799b5e1a8606901 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 10 Jul 2018 18:16:37 +0100 Subject: [PATCH 8/9] update cudnn rnn --- src/cuda/cudnn.jl | 84 ++++++++++++++---------------------------- src/tracker/Tracker.jl | 3 +- src/tracker/array.jl | 2 +- src/tracker/scalar.jl | 8 ++-- 4 files changed, 36 insertions(+), 61 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 28a9eec3..85b5b975 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -286,41 +286,28 @@ function desc(rnn) return d end -import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast - -mutable struct RNNCall{R} - rnn::R - reserve::CuVector{UInt8} - RNNCall{R}(rnn::R) where R = new(rnn) -end - -RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn) - -function (c::RNNCall)(args...) - rs, result = forwardTrain(desc(c.rnn), args...) - c.reserve = rs - return result -end +import Flux.Tracker +import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h) : + track(m, x, h, m.Wi, m.Wh, m.b) : forward(desc(m), x, h) return result[2], result[1] end function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h) : + track(m, x, h, m.Wi, m.Wh, m.b) : forward(desc(m), x, h) return result[2], result[1] end function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h[1], h[2]) : + track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) : forward(desc(m), x, h[1], h[2]) return (result[2], result[3]), result[1] end @@ -329,44 +316,29 @@ end (m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -function accum_transpose!(dst::CuArray, src::CuArray) - function kernel(dst, src) - I = @cuindex dst - dst[I...] += src[reverse(I)...] - return +@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data(x), data(h)) + result, function (Δ) + y, ho = result + dy, dho = Δ + h_ = hBatch(x, data(h)) + dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db)) end - blk, thr = cudims(dst) - @cuda (blk, thr) kernel(dst, src) - return dst end -# function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h) -# y, ho = y_ -# dy, dho = Δ -# h_ = hBatch(x, data(h)) -# dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) -# @back(x, dx) -# @back(h, unbroadcast(h, dh)) -# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) -# # We don't have to make this assumption, it's just slightly more complex. -# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) -# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) -# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) -# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -# end - -# function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) -# y, ho, co = y_ -# dy, dho, dco = Δ -# h_ = hBatch(x, data(h)) -# c_ = hBatch(x, data(c)) -# dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) -# @back(x, dx) -# @back(h, unbroadcast(h, dh)) -# @back(c, unbroadcast(h, dc)) -# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) -# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) -# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) -# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) -# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -# end +@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data.((x, h, c))...) + result, function (Δ) + y, ho = result + dy, dho, dco = Δ + h_ = hBatch(x, data(h)) + c_ = hBatch(x, data(c)) + dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, + (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc), + dWi.', dWh.', db)) + end +end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2a94edb7..4cbde1f0 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -55,7 +55,8 @@ macro grad(ex) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {T__} = body_)) || error("Need a function definition") T == nothing && (T = []) - insert!(args, 1+isexpr(args[1], :parameters) , :(::typeof($name))) + isexpr(name, :(::)) || (name = :(::typeof($name))) + insert!(args, 1+isexpr(args[1], :parameters) , name) @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc end diff --git a/src/tracker/array.jl b/src/tracker/array.jl index e034a868..6c7f93e3 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -66,7 +66,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) -@grad function getindex(xs, i...) +@grad function getindex(xs::AbstractArray, i...) data(xs)[i...], function (Δ) Δ′ = zero(xs) Δ′[i...] = data(Δ) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 7e574fd9..50b9c7af 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -79,13 +79,14 @@ struct TrackedTuple{T<:Tuple} tracker::Tracked{T} end +data(xs::TrackedTuple) = xs.data tracker(xs::TrackedTuple) = xs.tracker accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f)) +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs))) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -96,8 +97,9 @@ Base.length(x::TrackedTuple) = length(data(x)) Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) -back(::typeof(getindex), Δ, t, i) = - back(t, ntuple(j -> i == j ? Δ : 0, length(t))) +@grad function getindex(xs::TrackedTuple, i) + data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing) +end # Array collection From dda51a0140b8cae9eb3467d59cab0d685f67b571 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 11 Jul 2018 15:31:22 +0100 Subject: [PATCH 9/9] update docs --- docs/src/internals/tracker.md | 165 +++++++++++++++++--------------- docs/src/models/basics.md | 72 ++++++++++++-- docs/src/models/recurrence.md | 2 +- docs/src/training/optimisers.md | 8 +- src/tracker/Tracker.jl | 4 +- src/tracker/back.jl | 8 ++ 6 files changed, 169 insertions(+), 90 deletions(-) diff --git a/docs/src/internals/tracker.md b/docs/src/internals/tracker.md index b9addc34..2c134f12 100644 --- a/docs/src/internals/tracker.md +++ b/docs/src/internals/tracker.md @@ -6,6 +6,52 @@ Backpropagation, or reverse-mode automatic differentiation, is handled by the `F julia> using Flux.Tracker ``` +Here we discuss some more advanced uses of this module, as well as covering its internals. + +## Taking Gradients + +In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function. + +```julia +using Flux.Tracker + +Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked)) +``` + +`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`. + +```julia +using Flux.Tracker: forward + +y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9) + +back(1) # (3.0 (tracked), 2.0 (tracked)) +``` + +The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar. + +```julia +julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6]) +(param([4.0, 10.0, 18.0]), Flux.Tracker.#9) + +julia> back([1,1,1]) +(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0])) +``` + +We can also take gradients in-place. This can be useful if you only care about first-order gradients. + +```julia +a, b = param(2), param(3) + +c = a*b # 6.0 (tracked) + +Tracker.back!(c) + +Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0) +``` + +## Tracked Arrays + The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters: ```julia @@ -41,7 +87,48 @@ julia> x.grad -2.0 ``` -## Internals +You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`. + +## Custom Gradients + +We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`: + +```julia +minus(a, b) = a - b +``` + +Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: + +```julia +using Flux.Tracker: TrackedReal, track, @grad + +minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) +``` + +`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition. + +```julia +@grad function minus(a, b) + return minus(data(a),data(b)), Δ -> (Δ, -Δ) +end +``` + +This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress). + +Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nest AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this. + +```julia +@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ) +``` + +For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: + +```julia +minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b) +minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b) +``` + +## Tracked Internals All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field. @@ -50,14 +137,9 @@ julia> x.tracker Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0]) ``` -The `Tracker` stores the value and gradient of a given object, which we've seen before. +The `Tracker` stores the gradient of a given object, which we've seen before. ```julia -julia> x.tracker.data -2-element Array{Float64,1}: - 5.0 - 6.0 - julia> x.tracker.grad 2-element Array{Float64,1}: -2.0 @@ -86,71 +168,4 @@ When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forward Tracker.back(*, [1, -1], W, x) ``` -which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. - -## Custom Gradients - -We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`: - -```julia -julia> minus(a, b) = a - b -``` - -Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch: - -```julia -julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b) -minus (generic function with 2 methods) -``` - -`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* array, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened: - -```julia -julia> a, b = param([6,5,4]), param([1,2,3]) -(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])) - -julia> c = minus(a, b) -Tracked 3-element Array{Float64,1}: - 5.0 - 3.0 - 1.0 - -julia> c.tracker.f -Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))) -``` - -Finally, we have to specify the gradient of `minus`. - -```julia -julia> Tracker.back(::typeof(minus), Δ, a, b) = - (Tracker.@back(a, Δ); Tracker.@back(b, -Δ)) -``` - -`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`. - -```julia -julia> Flux.back!(c, 1) - -julia> a.grad -3-element Array{Float64,1}: - 1.0 - 1.0 - 1.0 - -julia> b.grad -3-element Array{Float64,1}: - -1.0 - -1.0 - -1.0 -``` - -## Notes - -For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, TrackedArray)` and so on. To do so, just define those extra signatures as needed: - -```julia -minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b) -minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b) -``` - -`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op. +which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 96efc7b8..134e251b 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -2,20 +2,74 @@ ## Taking Gradients -Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. (It's a good idea to follow this example in the Julia repl.) +Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.) + +```julia +using Flux.Tracker + +f(x) = 3x^2 + 2x + 1 + +# df/dx = 6x + 2 +f′(x) = Tracker.gradient(f, x)[1] + +f′(2) # 14.0 (tracked) + +# d²f/dx² = 6 +f′′(x) = Tracker.gradient(f′, x)[1] + +f′′(2) # 6.0 (tracked) +``` + +(We'll learn more about why these numbers show up as `(tracked)` below.) + +When a function has many parameters, we can pass them all in explicitly: + +```julia +f(W, b, x) = W * x + b + +Tracker.gradient(f, 2, 3, 4) +(4.0 (tracked), 1.0, 2.0 (tracked)) +``` + +But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once. + +```julia +W = param(2) # 2.0 (tracked) +b = param(3) # 3.0 (tracked) + +f(x) = W * x + b + +params = Params([W, b]) +grads = Tracker.gradient(() -> f(4), params) + +grads[W] # 4.0 +grads[b] # 1.0 +``` + +There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate. + +This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple. + +## Simple Models + +Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. ```julia W = rand(2, 5) b = rand(2) predict(x) = W*x .+ b -loss(x, y) = sum((predict(x) .- y).^2) + +function loss(x, y) + ŷ = predict(x) + sum((y .- ŷ).^2) +end x, y = rand(5), rand(2) # Dummy data loss(x, y) # ~ 3 ``` -To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*. +To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above. ```julia using Flux.Tracker @@ -23,17 +77,15 @@ using Flux.Tracker W = param(W) b = param(b) -l = loss(x, y) - -back!(l) +gs = Tracker.gradient(() -> loss(x, y), Params([W, b])) ``` -`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. +Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent. ```julia -using Flux.Tracker: grad, update! +using Flux.Tracker: update! -Δ = grad(W) +Δ = gs[W] # Update the parameter and reset the gradient update!(W, -0.1Δ) @@ -43,7 +95,7 @@ loss(x, y) # ~ 2.5 The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md). -All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow, and there are ways to manage this complexity. Let's see what that looks like. +All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different – they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models. ## Building Layers diff --git a/docs/src/models/recurrence.md b/docs/src/models/recurrence.md index befe32dd..7c20165d 100644 --- a/docs/src/models/recurrence.md +++ b/docs/src/models/recurrence.md @@ -103,7 +103,7 @@ m.(seq) ## Truncating Gradients -By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive. +By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients – this accumulates and quickly becomes expensive. To avoid this we can *truncate* the gradient calculation, forgetting the history. diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index ac58f6d0..968622be 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -3,6 +3,8 @@ Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. ```julia +using Flux.Tracker + W = param(rand(2, 5)) b = param(rand(2)) @@ -11,7 +13,9 @@ loss(x, y) = sum((predict(x) .- y).^2) x, y = rand(5), rand(2) # Dummy data l = loss(x, y) # ~ 3 -back!(l) + +params = Params([W, b]) +grads = Tracker.gradient(() -> loss(x, y), params) ``` We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: @@ -22,7 +26,7 @@ using Flux.Tracker: grad, update! function sgd() η = 0.1 # Learning Rate for p in (W, b) - update!(p, -η * grad(p)) + update!(p, -η * grads[p]) end end ``` diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 4cbde1f0..65b8db11 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -5,7 +5,7 @@ using MacroTools: @q, @forward import Base: == -export TrackedArray, TrackedVector, TrackedMatrix, param, back! +export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back! tracker(x) = nothing @@ -61,7 +61,7 @@ macro grad(ex) end function update!(x, Δ) - tracker(x).data += Δ + x.data .+= data(Δ) tracker(x).grad .= 0 return x end diff --git a/src/tracker/back.jl b/src/tracker/back.jl index 8e492861..08cf9d6a 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -72,10 +72,18 @@ end @forward Params.params Base.start, Base.next, Base.done +function Base.show(io::IO, ps::Params) + print(io, "Params([") + join(io, ps.params, ", ") + print(io, "])") +end + struct Grads grads::ObjectIdDict end +Base.show(io::IO, ps::Grads) = println(io, "Grads(...)") + Grads() = Grads(ObjectIdDict()) Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))