diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 28a9eec3..85b5b975 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -286,41 +286,28 @@ function desc(rnn) return d end -import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast - -mutable struct RNNCall{R} - rnn::R - reserve::CuVector{UInt8} - RNNCall{R}(rnn::R) where R = new(rnn) -end - -RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn) - -function (c::RNNCall)(args...) - rs, result = forwardTrain(desc(c.rnn), args...) - c.reserve = rs - return result -end +import Flux.Tracker +import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h) : + track(m, x, h, m.Wi, m.Wh, m.b) : forward(desc(m), x, h) return result[2], result[1] end function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h) : + track(m, x, h, m.Wi, m.Wh, m.b) : forward(desc(m), x, h) return result[2], result[1] end function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} result = istrain(m, h, x) ? - track(RNNCall(m), x, h[1], h[2]) : + track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) : forward(desc(m), x, h[1], h[2]) return (result[2], result[3]), result[1] end @@ -329,44 +316,29 @@ end (m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -function accum_transpose!(dst::CuArray, src::CuArray) - function kernel(dst, src) - I = @cuindex dst - dst[I...] += src[reverse(I)...] - return +@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data(x), data(h)) + result, function (Δ) + y, ho = result + dy, dho = Δ + h_ = hBatch(x, data(h)) + dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db)) end - blk, thr = cudims(dst) - @cuda (blk, thr) kernel(dst, src) - return dst end -# function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h) -# y, ho = y_ -# dy, dho = Δ -# h_ = hBatch(x, data(h)) -# dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) -# @back(x, dx) -# @back(h, unbroadcast(h, dh)) -# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) -# # We don't have to make this assumption, it's just slightly more complex. -# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) -# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) -# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) -# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -# end - -# function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) -# y, ho, co = y_ -# dy, dho, dco = Δ -# h_ = hBatch(x, data(h)) -# c_ = hBatch(x, data(c)) -# dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) -# @back(x, dx) -# @back(h, unbroadcast(h, dh)) -# @back(c, unbroadcast(h, dc)) -# (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) -# @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) -# istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) -# istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) -# istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) -# end +@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data.((x, h, c))...) + result, function (Δ) + y, ho = result + dy, dho, dco = Δ + h_ = hBatch(x, data(h)) + c_ = hBatch(x, data(c)) + dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, + (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc), + dWi.', dWh.', db)) + end +end diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 2a94edb7..4cbde1f0 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -55,7 +55,8 @@ macro grad(ex) @capture(shortdef(ex), (name_(args__) = body_) | (name_(args__) where {T__} = body_)) || error("Need a function definition") T == nothing && (T = []) - insert!(args, 1+isexpr(args[1], :parameters) , :(::typeof($name))) + isexpr(name, :(::)) || (name = :(::typeof($name))) + insert!(args, 1+isexpr(args[1], :parameters) , name) @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc end diff --git a/src/tracker/array.jl b/src/tracker/array.jl index e034a868..6c7f93e3 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -66,7 +66,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) -@grad function getindex(xs, i...) +@grad function getindex(xs::AbstractArray, i...) data(xs)[i...], function (Δ) Δ′ = zero(xs) Δ′[i...] = data(Δ) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 7e574fd9..50b9c7af 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -79,13 +79,14 @@ struct TrackedTuple{T<:Tuple} tracker::Tracked{T} end +data(xs::TrackedTuple) = xs.data tracker(xs::TrackedTuple) = xs.tracker accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f)) +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs))) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -96,8 +97,9 @@ Base.length(x::TrackedTuple) = length(data(x)) Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) -back(::typeof(getindex), Δ, t, i) = - back(t, ntuple(j -> i == j ? Δ : 0, length(t))) +@grad function getindex(xs::TrackedTuple, i) + data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing) +end # Array collection