diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 92debbdd..20130b1d 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -123,7 +123,7 @@ function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T
   dx = similar(x)
   cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
                    training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
-  (dx, db, dg)
+  (dg, db, dx)
 end
 
 function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
@@ -176,94 +176,22 @@ end
 
 # Flux Interface
 
-<<<<<<< HEAD
 import ..Flux: Flux
 import ..Tracker: track, back, @back, istracked, TrackedArray
+import ..Tracker: data, @grad, nobacksies
 
 (BN::Flux.BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
-  batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
-=======
-function desc(rnn)
-  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
-  copyparams!(rnn, d)
-  return d
-end
+  batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
 
-import Flux.Tracker
-import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
-
-istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
-
-function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
-  return result[2], result[1]
-end
-
-function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h, m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h)
-  return result[2], result[1]
-end
-
-function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  result = istrain(m, h, x) ?
-    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
-    forward(desc(m), x, h[1], h[2])
-  return (result[2], result[3]), result[1]
-end
->>>>>>> 071dcdda879a74cfd3c1115ac2c92087b38d4ae9
 
 _batchnorm(g, b, x, running_mean, running_var, momentum,
            cache, alpha, beta, eps, training) =
   batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache,
             alpha = alpha, beta = beta, eps = eps, training = training)
 
-<<<<<<< HEAD
 batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum;
-          cache = nothing, alpha = T(1), beta = T(0),
-          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
-  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
 
 batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
-          running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0),
-          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
-  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
 
-function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
-  deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
-                         cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
-  @back(x, deriv_tup[1])
-  @back(b, deriv_tup[2])
-  @back(g, deriv_tup[3])
-=======
-@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data(x), data(h))
-  result, function (Δ)
-    y, ho = result
-    dy, dho = Δ
-    h_ = hBatch(x, data(h))
-    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
-  end
-end
-
-@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
-  result, function (Δ)
-    y, ho = result
-    dy, dho, dco = Δ
-    h_ = hBatch(x, data(h))
-    c_ = hBatch(x, data(c))
-    dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
-    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN,
-      (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       dWi.', dWh.', db))
-  end
->>>>>>> 071dcdda879a74cfd3c1115ac2c92087b38d4ae9
-end
+@grad function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+  y = batchnorm(data(g), data(b), data(x), running_mean, running_var, momentum; kw...)
+  y, Δ -> (nobacksies(:batchnorm,
+                      ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))...,
+           nothing, nothing, nothing)
+end
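A note on the `@grad` rule above, before the RNN changes: Tracker's `@grad` replaces the old `back(::typeof(_batchnorm), ...)` overload with a forward pass that returns the primal value plus a pullback closure. `nobacksies` wraps the CuDNN-computed gradients so Tracker will not try to differentiate through them, and the trailing `nothing`s are the (non-existent) gradients for `running_mean`, `running_var` and `momentum`. A minimal sketch of the same wiring on a toy function; `scale` and its pullback are illustrative only, not part of this diff:

```julia
import Flux.Tracker: TrackedArray, track, data, @grad, nobacksies

# Toy elementwise op standing in for the CuDNN kernel call.
scale(a, x) = a .* x

# Route tracked arguments through Tracker, mirroring the batchnorm methods.
scale(a::TrackedArray, x::TrackedArray) = track(scale, a, x)

# Forward returns (primal, pullback); the pullback yields one gradient per
# tracked argument, wrapped in `nobacksies` because the gradients themselves
# are computed outside the tape.
@grad function scale(a, x)
  y = scale(data(a), data(x))
  y, Δ -> nobacksies(:scale, (Δ .* data(x),   # da for y = a .* x
                              Δ .* data(a)))  # dx
end
```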
diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl
index 905b1ef4..ed65f5e7 100644
--- a/src/cuda/curnn.jl
+++ b/src/cuda/curnn.jl
@@ -265,41 +265,28 @@ function desc(rnn)
   return d
 end
 
-import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
-
-mutable struct RNNCall{R}
-  rnn::R
-  reserve::CuVector{UInt8}
-  RNNCall{R}(rnn::R) where R = new(rnn)
-end
-
-RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
-
-function (c::RNNCall)(args...)
-  rs, result = forwardTrain(desc(c.rnn), args...)
-  c.reserve = rs
-  return result
-end
+import Flux.Tracker
+import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
 
 istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
 
 function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h) :
+    track(m, x, h, m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h)
   return result[2], result[1]
 end
 
 function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h) :
+    track(m, x, h, m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h)
   return result[2], result[1]
 end
 
 function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h[1], h[2]) :
+    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h[1], h[2])
   return (result[2], result[3]), result[1]
 end
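The hunk above drops the stateful `RNNCall` wrapper, which stashed the CUDNN `reserve` buffer on the side and later accumulated weight gradients by hand; instead, `m.Wi`, `m.Wh`, `m.b` are passed to `track` as explicit arguments, so Tracker owns the weight gradients (the matching `@grad` rules follow in the next hunk). A small sketch of this train/inference dispatch pattern; `ToyCell` and `fastforward` are illustrative names, not part of the diff:

```julia
import Flux.Tracker: TrackedArray, track, data, @grad, nobacksies

struct ToyCell{W}
  W::W
end

# Inference path: plain arrays, no gradient tape.
fastforward(c::ToyCell, x) = data(c.W) * x

# Take the recorded path only when the weight or input is tracked; the
# weight is an explicit `track` argument so Tracker accumulates its grad.
(c::ToyCell)(x) = any(a -> a isa TrackedArray, (c.W, x)) ?
  track(c, x, c.W) :
  fastforward(c, x)

# Pullback returns one gradient per tracked argument (x, then W).
@grad function (c::ToyCell)(x, W)
  y = data(W) * data(x)
  y, Δ -> nobacksies(:ToyCell, (data(W)' * Δ,   # dx for y = W * x
                                Δ * data(x)'))  # dW
end
```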
@@ -308,44 +295,29 @@ end
 (m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
-function accum_transpose!(dst::CuArray, src::CuArray)
-  function kernel(dst, src)
-    I = @cuindex dst
-    dst[I...] += src[reverse(I)...]
-    return
-  end
-  blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
-  return dst
-end
-
-function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
-  y, ho = y_
-  dy, dho = Δ
-  h_ = hBatch(x, data(h))
-  dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
-  @back(x, dx)
-  @back(h, unbroadcast(h, dh))
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
-  # We don't have to make this assumption, it's just slightly more complex.
-  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
-  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
-  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
-  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
-end
-
-function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
-  y, ho, co = y_
-  dy, dho, dco = Δ
-  h_ = hBatch(x, data(h))
-  c_ = hBatch(x, data(c))
-  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
-  @back(x, dx)
-  @back(h, unbroadcast(h, dh))
-  @back(c, unbroadcast(h, dc))
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
-  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
-  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
-  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
-  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
+@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
+  reserve, result = forwardTrain(desc(m), data(x), data(h))
+  result, function (Δ)
+    y, ho = result
+    dy, dho = Δ
+    h_ = hBatch(x, data(h))
+    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+  end
+end
+
+@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
+  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
+  result, function (Δ)
+    y, ho = result
+    dy, dho, dco = Δ
+    h_ = hBatch(x, data(h))
+    c_ = hBatch(x, data(c))
+    dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    nobacksies(:RNN,
+      (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
+       dWi.', dWh.', db))
+  end
 end
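In the new rules, CUDNN hands back `dWi`/`dWh`/`db` in its own transposed layout, so the `dWi.'`/`dWh.'` transposes replace the old custom `accum_transpose!` kernel, and `unbroadcast(size(h), dh)` trims the batched hidden-state gradient back to the shape of `h` before Tracker accumulates it. A hypothetical smoke test for the tracked path, assuming a working CuArrays setup; nothing below is part of the diff:

```julia
using Flux, CuArrays
import Flux.Tracker

m = RNN(10, 5) |> gpu      # cell params become CuParams, hitting the methods above
x = gpu(rand(Float32, 10))

y = m(x)                   # tracked weights make istrain true, so forwardTrain records a pullback
Tracker.back!(sum(y))      # pullback runs backwardData/backwardWeights
Tracker.grad(m.cell.Wi)    # weight gradient, accumulated by Tracker in Flux's layout
```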