Update CUDNN Batchnorm with new Flux AD

This commit is contained in:
Avik Pal 2018-07-17 09:40:20 +05:30
parent 646db81f94
commit 0bb3eaa1f6
2 changed files with 36 additions and 136 deletions

View File

@ -123,7 +123,7 @@ function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
(dx, db, dg)
(dg, db, dx)
end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
@ -176,94 +176,22 @@ end
# Flux Interface
<<<<<<< HEAD
import ..Flux: Flux
import ..Tracker: track, back, @back, istracked, TrackedArray
(BN::Flux.BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
=======
function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
copyparams!(rnn, d)
return d
end
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
>>>>>>> 071dcdda879a74cfd3c1115ac2c92087b38d4ae9
_batchnorm(g, b, x, running_mean, running_var, momentum,
cache, alpha, beta, eps, training) =
batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training)
<<<<<<< HEAD
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
track(_batchnorm, g, b, x, running_mean, running_var, momentum, kw...)
function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
@grad function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
y = batchnorm(data(g), data(b), data(x), running_mean, running_var, momentum; kw...)
deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
@back(x, deriv_tup[1])
@back(b, deriv_tup[2])
@back(g, deriv_tup[3])
=======
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
end
end
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
dWi.', dWh.', db))
end
>>>>>>> 071dcdda879a74cfd3c1115ac2c92087b38d4ae9
end
y, Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.(g, b, x, Δ), running_mean, running_var, momentum; kw...)), nothing, nothing, nothing)

View File

@ -265,41 +265,28 @@ function desc(rnn)
return d
end
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
mutable struct RNNCall{R}
rnn::R
reserve::CuVector{UInt8}
RNNCall{R}(rnn::R) where R = new(rnn)
end
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
function (c::RNNCall)(args...)
rs, result = forwardTrain(desc(c.rnn), args...)
c.reserve = rs
return result
end
import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h) :
track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h)
return result[2], result[1]
end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ?
track(RNNCall(m), x, h[1], h[2]) :
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1]
end
@ -308,44 +295,29 @@ end
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
function accum_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] += src[reverse(I)...]
return
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data(x), data(h))
result, function (Δ)
y, ho = result
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
y, ho = y_
dy, dho = Δ
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
# We don't have to make this assumption, it's just slightly more complex.
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
y, ho, co = y_
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
@back(x, dx)
@back(h, unbroadcast(h, dh))
@back(c, unbroadcast(h, dc))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ
h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
dWi.', dWh.', db))
end
end