diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 0065f17b..82982180 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -7,6 +7,12 @@ if !applicable(CuArray{UInt8}, undef, 1) end if CuArrays.libcudnn != nothing + if isdefined(CuArrays, :libcudnn_handle) + handle() = CuArrays.libcudnn_handle[] + else + handle() = CuArrays.CUDNN.handle() + end + include("curnn.jl") include("cudnn.jl") else @warn("CUDNN is not installed, some functionality will not be available.") diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index edb96449..8bd8135e 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,13 +1,8 @@ -using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, - cudnnDataType, TensorDesc, FilterDesc +using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t, + cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc +import ..Flux: data using LinearAlgebra -if isdefined(CuArrays, :libcudnn_handle) - handle() = CuArrays.libcudnn_handle[] -else - handle() = CuArrays.CUDNN.handle() -end - mutable struct DropoutDesc ptr::Ptr{Nothing} states::CuVector{UInt8} @@ -30,324 +25,204 @@ function DropoutDesc(ρ::Real; seed::Integer=0) return desc end -const RNN_RELU = 0 # Stock RNN with ReLu activation -const RNN_TANH = 1 # Stock RNN with tanh activation -const LSTM = 2 # LSTM with no peephole connections -const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1) +const BATCHNORM_SPATIAL = 1 +const BATCHNORM_ACTIVATION = 0 +const BATCHNORM_MIN_EPS = 1e-5 -const LINEAR_INPUT = 0 -const SKIP_INPUT = 1 +@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1) -const UNIDIRECTIONAL = 0 -const BIDIRECTIONAL = 1 +@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y)) -const RNN_ALGO_STANDARD = 0 -const RNN_ALGO_PERSIST_STATIC = 1 -const RNN_ALGO_PERSIST_DYNAMIC = 2 - -# param layout: -# RNN: [weight, bias] × [input, hidden] -# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem] -# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output] - -function params(w::CuVector, input, hidden, n = 1) - slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape) - wx = slice(0, (input, hidden*n)) - wh = slice(length(wx), (hidden, hidden*n)) - bias = view(w, length(wx)+length(wh) .+ (1:hidden*n)) - (wx, wh), bias +mutable struct BNCache + mean + ivar end -mutable struct RNNDesc{T} - mode::Int - input::Int - hidden::Int - params::CuVector{T} - weights::NTuple{2,CuMatrix{T}} - bias::CuVector{T} - ptr::Ptr{Nothing} +BNCache() = BNCache(nothing, nothing) + +# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations +# so reshape a 2D Tensor into 4D +batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, + running_mean::CuArray{T}, running_var::CuArray{T}, momentum; + cache = nothing, alpha = T(1), beta = T(0), + eps = T(1e-5), training = true) where T<:Union{Float32, Float64} = + dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum, + cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2)) + +function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}}, + running_mean::CuArray{T}, running_var::CuArray{T}, momentum; + cache = nothing, alpha = T(1), beta = T(0), + eps = T(1e-5), training = true) where T<:Union{Float32, Float64} + y = similar(x) + cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache, + alpha = alpha, beta = beta, eps = eps, 
training = training) + y end -Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr - -function rnnParamSize(T, r, input) - size = Csize_t[0] - @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint), - handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) - return Int(size[])÷sizeof(T) -end - -ngates(mode) = [1, 1, 4, 3][mode+1] -ngates(r::RNNDesc) = ngates(r.mode) - -function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T - d = [C_NULL] - @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d) - - dropoutDesc = DropoutDesc(0) - inputMode = LINEAR_INPUT - direction = UNIDIRECTIONAL - algo = RNN_ALGO_STANDARD - @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint), - handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) - - w = cuzeros(T, rnnParamSize(T, d[], input)) - # TODO: avoid reserve allocation here - rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) - finalizer(rd) do x - @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) +function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, + running_mean::CuArray{T}, running_var::CuArray{T}, + momentum; cache = nothing, + alpha = T(1), beta = T(0), + eps = T(1e-5), training = true) where T<:Union{Float32, Float64} + dims = _wsize(x) + if eps < BATCHNORM_MIN_EPS + # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS) + eps = BATCHNORM_MIN_EPS end - return rd -end + xd = TensorDesc(x) + yd = TensorDesc(y) + gd = TensorDesc(T, dims) -function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) - size = Csize_t[0] - @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) - return Int(size[]) -end + if training -const workspace = [CuVector{UInt8}(undef, 1)] + if cache !== nothing + mean = zeros(CuArray{T}, dims...) + ivar = ones(CuArray{T}, dims...) + else + mean = C_NULL + ivar = C_NULL + end -getworkspace(bytes) = - length(workspace[]) ≥ bytes ? 
- workspace[] : - (workspace[] = CuVector{UInt8}(undef, bytes)) - -getworkspace(r::RNNDesc, seqlen, xdesc) = - getworkspace(rnnWorkspaceSize(r, seqlen, xdesc)) - -function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) - size = Csize_t[0] - @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}), - handle(), r, seqlen, xdesc, size) - return Int(size[]) -end - -function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, reserve=nothing) where T - if reserve == nothing - @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t, + (cudnnHandle_t,cudnnBatchNormMode_t, + Ptr{T}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace)) + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, Ptr{T}, + Cdouble, Ptr{T}, Ptr{T}, + Cdouble, Ptr{T}, Ptr{T}), + handle(), BATCHNORM_SPATIAL, + Ref(T(alpha)), Ref(T(beta)), + xd, x, + yd, y, + gd, g, b, + momentum, running_mean, running_var, + eps, mean, ivar) + + if cache !== nothing + cache.mean = mean + cache.ivar = ivar + end else - @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), - handle(), rnn, seqlen, - xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, - workspace, length(workspace), reserve, length(reserve)) + @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, + (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, + Ptr{T}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, Ptr{T}, + Ptr{T}, Ptr{T}, + Cdouble), + handle(), BATCHNORM_SPATIAL, + Ref(T(alpha)), Ref(T(beta)), + xd, x, + yd, y, + gd, g, b, + running_mean, running_var, + eps) end end -xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] - -hDesc(h::Nothing) = C_NULL, C_NULL -hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) -function hDesc(h::CuArray) - TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h +function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2}, + running_mean::CuArray{T}, running_var::CuArray{T}, momentum; + cache = nothing, eps = T(1e-5), alpha = T(1), + beta = T(0), training = true) where T<:Union{Float32, Float64} + dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), + size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps, + alpha = alpha, beta = beta, training = training) + (dg, db, dropdims(dx, dims = (1, 2))) end -# TODO: can we just manipulate strides here? -# TODO: should use repmat, but this isn't implemented. -hBatch(x::AbstractVector, h::CuVector) = h -hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2)) -hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? 
size(x,2) : 1) - -function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T - h = hBatch(x, h_) - c = c_ == nothing ? nothing : hBatch(x, c_) - @assert size(x, 1) == rnn.input - @assert size(h, 1) == rnn.hidden - @assert size(x, 2) == size(h, 2) - seqLength = 1 - xdesc = xDesc(x) - y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) - ho = similar(h) - ydesc = xDesc(y) - workspace = getworkspace(rnn, seqLength, xdesc) - reserve = train == Val{true} ? - CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) : - nothing - co = c == nothing ? c : similar(c) - cudnnRNNForward(rnn, seqLength, - xdesc, x, - hDesc(h)..., - hDesc(c)..., - FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, - ydesc, y, - hDesc(ho)..., - hDesc(co)..., - workspace, reserve) - result = c == nothing ? (y, ho) : (y, ho, co) - return train == Val{true} ? (reserve, result) : result +function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, + running_mean::CuArray{T}, running_var::CuArray{T}, momentum; + cache = nothing, eps = T(1e-5), alpha = T(1), + beta = T(0), training = true) where T<:Union{Float32, Float64} + dg = similar(g) + db = similar(b) + dx = similar(x) + cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum), + training = training, cache = cache, eps = eps, alpha = alpha, beta = beta) + (dg, db, dx) end -forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = - forward(rnn, x, h, c, Val{true}) +function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, + dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, + running_mean::CuArray{T}, running_var::CuArray{T}, + momentum; cache = nothing, eps = T(1e-5), + alpha = T(1), beta = T(0), + dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64} + if training + xd = TensorDesc(x) + dyd = TensorDesc(dy) + dxd = TensorDesc(dx) + gd = TensorDesc(T, _wsize(x)) + if cache !== nothing + mean, ivar = cache.mean, cache.ivar + info("mean and ivar are fetched from the cache") + else + mean, ivar = C_NULL, C_NULL + end -function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T - @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, - Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, - Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, - Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), - handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, - wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) -end + if eps < BATCHNORM_MIN_EPS + eps = BATCHNORM_MIN_EPS + end -function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T - # Same as above, any more efficient way? - dy = dy_ isa Integer ? zero(y) : dy_ - yd = xDesc(y) - dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) - dh = similar(h) - dc = c == nothing ? nothing : similar(c) - cudnnRNNBackwardData(rnn, 1, - yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., - FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, - hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., - workspace[], reserve) - return c == nothing ? 
(dx, dh) : (dx, dh, dc) -end - -backwardData(rnn, y, dy, dho, hx, reserve) = - backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) - -function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, - workspace, reserve) where T - @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, - (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength - Ptr{Ptr{Nothing}}, Ptr{T}, #x - Ptr{Nothing}, Ptr{T}, #hx - Ptr{Ptr{Nothing}}, Ptr{T}, #y - Ptr{Nothing}, Csize_t, #ws - Ptr{Nothing}, Ptr{T}, #dw - Ptr{Nothing}, Csize_t), #rs - handle(), rnn, seqlen, xd, x, hd, h, yd, y, - workspace, length(workspace), dwd, dw, reserve, length(reserve)) -end - -function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T - dw = zero(rnn.params) - cudnnRNNBackwardWeights(rnn, 1, - xDesc(x), x, hDesc(h)..., xDesc(y), y, - FilterDesc(T, (1, 1, length(dw))), dw, - workspace[], reserve) - return params(dw, rnn.input, rnn.hidden, ngates(rnn)) -end - -# Interface - -import ..Flux: Flux, relu -import ..Tracker: TrackedArray -using .CuArrays.CUDAnative -using .CuArrays: @cuindex, cudims - -function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray) - function kernel(dst, src) - I = @cuindex dst - dst[I...] = src[reverse(I)...] - return - end - blk, thr = cudims(dst) - @cuda blocks=blk threads=thr kernel(dst, src) - return dst -end - -CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} -CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} -CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} -CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}} -CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} - -function copyparams!(m::CuRNNs, d::RNNDesc) - Wi, Wh = d.weights - copy_transpose!(Wi, Flux.data(m.Wi)) - copy_transpose!(Wh, Flux.data(m.Wh)) - copy_transpose!(d.bias, Flux.data(m.b)) - return -end - -function RNNDesc(m::CuRNNs{T}) where T - h, i = length(m.h), size(m.Wi, 2) - mode = m isa CuRNN ? - (m.σ == tanh ? RNN_TANH : RNN_RELU) : - m isa CuGRU ? GRU : LSTM - r = RNNDesc{T}(mode, i, h) - return r -end - -const descs = WeakKeyDict() - -function desc(rnn) - d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn)) - copyparams!(rnn, d) - return d -end - -import Flux.Tracker -import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies - -istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) - -function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} - result = istrain(m, h, x) ? - track(m, x, h, m.Wi, m.Wh, m.b) : - forward(desc(m), x, h) - return result[2], result[1] -end - -function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} - result = istrain(m, h, x) ? - track(m, x, h, m.Wi, m.Wh, m.b) : - forward(desc(m), x, h) - return result[2], result[1] -end - -function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} - result = istrain(m, h, x) ? 
- track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) : - forward(desc(m), x, h[1], h[2]) - return (result[2], result[3]), result[1] -end - -(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) -(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) - -@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) - reserve, result = forwardTrain(desc(m), data(x), data(h)) - result, function (Δ) - y, ho = result - dy, dho = Δ - h_ = hBatch(x, data(h)) - dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) - (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) - nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) + @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, + (cudnnHandle_t,cudnnBatchNormMode_t, + Ptr{T}, Ptr{T}, + Ptr{T}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T}, + Cdouble, Ptr{T}, Ptr{T}), + handle(), BATCHNORM_SPATIAL, + Ref(T(alpha)), Ref(T(beta)), + Ref(T(dalpha)), Ref(T(dbeta)), + xd, x, + dyd, dy, + dxd, dx, + gd, g, dg, db, + eps, mean, ivar) + else + ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps) + dx .= dy .* reshape(g, _wsize(x)) .* ivar + dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4)) + db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4)) end end -@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) - reserve, result = forwardTrain(desc(m), data.((x, h, c))...) - result, function (Δ) - y, ho = result - dy, dho, dco = Δ - h_ = hBatch(x, data(h)) - c_ = hBatch(x, data(c)) - dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) - (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) - nobacksies(:RNN, - (dx, unbroadcast(h, dh), unbroadcast(c, dc), - transpose(dWi), transpose(dWh), db)) - end -end +# Flux Interface + +(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = + batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active) + +batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) 
+ +batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T}, + running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} = + track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...) + +@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) = + batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl new file mode 100644 index 00000000..210ddd7c --- /dev/null +++ b/src/cuda/curnn.jl @@ -0,0 +1,325 @@ +using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t, + cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc +using LinearAlgebra + +const RNN_RELU = 0 # Stock RNN with ReLu activation +const RNN_TANH = 1 # Stock RNN with tanh activation +const LSTM = 2 # LSTM with no peephole connections +const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1) + +const LINEAR_INPUT = 0 +const SKIP_INPUT = 1 + +const UNIDIRECTIONAL = 0 +const BIDIRECTIONAL = 1 + +const RNN_ALGO_STANDARD = 0 +const RNN_ALGO_PERSIST_STATIC = 1 +const RNN_ALGO_PERSIST_DYNAMIC = 2 + +# param layout: +# RNN: [weight, bias] × [input, hidden] +# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem] +# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output] + +function params(w::CuVector, input, hidden, n = 1) + slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape) + wx = slice(0, (input, hidden*n)) + wh = slice(length(wx), (hidden, hidden*n)) + bias = view(w, length(wx)+length(wh) .+ (1:hidden*n)) + (wx, wh), bias +end + +mutable struct RNNDesc{T} + mode::Int + input::Int + hidden::Int + params::CuVector{T} + weights::NTuple{2,CuMatrix{T}} + bias::CuVector{T} + ptr::Ptr{Nothing} +end + +Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr + +function rnnParamSize(T, r, input) + size = Csize_t[0] + @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint), + handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) + return Int(size[])÷sizeof(T) +end + +ngates(mode) = [1, 1, 4, 3][mode+1] +ngates(r::RNNDesc) = ngates(r.mode) + +function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T + d = [C_NULL] + @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d) + + dropoutDesc = DropoutDesc(0) + inputMode = LINEAR_INPUT + direction = UNIDIRECTIONAL + algo = RNN_ALGO_STANDARD + @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint), + handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + + w = cuzeros(T, rnnParamSize(T, d[], input)) + # TODO: avoid reserve allocation here + rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) + finalizer(rd) do x + @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x) + end + return rd +end + +function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) + size = Csize_t[0] + 
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}), + handle(), r, seqlen, xdesc, size) + return Int(size[]) +end + +const workspace = [CuVector{UInt8}(undef, 1)] + +getworkspace(bytes) = + length(workspace[]) ≥ bytes ? + workspace[] : + (workspace[] = CuVector{UInt8}(undef, bytes)) + +getworkspace(r::RNNDesc, seqlen, xdesc) = + getworkspace(rnnWorkspaceSize(r, seqlen, xdesc)) + +function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) + size = Csize_t[0] + @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}), + handle(), r, seqlen, xdesc, size) + return Int(size[]) +end + +function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, reserve=nothing) where T + if reserve == nothing + @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, + (Ptr{Nothing}, Ptr{Nothing}, Cint, + Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Csize_t), + handle(), rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace)) + else + @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, + (Ptr{Nothing}, Ptr{Nothing}, Cint, + Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), + handle(), rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace), reserve, length(reserve)) + end +end + +xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] + +hDesc(h::Nothing) = C_NULL, C_NULL +hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) +function hDesc(h::CuArray) + TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h +end + +# TODO: can we just manipulate strides here? +# TODO: should use repmat, but this isn't implemented. +hBatch(x::AbstractVector, h::CuVector) = h +hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2)) +hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1) + +function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T + h = hBatch(x, h_) + c = c_ == nothing ? nothing : hBatch(x, c_) + @assert size(x, 1) == rnn.input + @assert size(h, 1) == rnn.hidden + @assert size(x, 2) == size(h, 2) + seqLength = 1 + xdesc = xDesc(x) + y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) + ho = similar(h) + ydesc = xDesc(y) + workspace = getworkspace(rnn, seqLength, xdesc) + reserve = train == Val{true} ? + CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) : + nothing + co = c == nothing ? c : similar(c) + cudnnRNNForward(rnn, seqLength, + xdesc, x, + hDesc(h)..., + hDesc(c)..., + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, + ydesc, y, + hDesc(ho)..., + hDesc(co)..., + workspace, reserve) + result = c == nothing ? (y, ho) : (y, ho, co) + return train == Val{true} ? 
(reserve, result) : result +end + +forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = + forward(rnn, x, h, c, Val{true}) + +function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T + @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, + (Ptr{Nothing}, Ptr{Nothing}, Cint, + Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, + Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, + Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t), + handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) +end + +function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T + # Same as above, any more efficient way? + dy = dy_ isa Integer ? zero(y) : dy_ + yd = xDesc(y) + dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) + dh = similar(h) + dc = c == nothing ? nothing : similar(c) + cudnnRNNBackwardData(rnn, 1, + yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, + hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., + workspace[], reserve) + return c == nothing ? (dx, dh) : (dx, dh, dc) +end + +backwardData(rnn, y, dy, dho, hx, reserve) = + backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) + +function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, + workspace, reserve) where T + @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, + (Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength + Ptr{Ptr{Nothing}}, Ptr{T}, #x + Ptr{Nothing}, Ptr{T}, #hx + Ptr{Ptr{Nothing}}, Ptr{T}, #y + Ptr{Nothing}, Csize_t, #ws + Ptr{Nothing}, Ptr{T}, #dw + Ptr{Nothing}, Csize_t), #rs + handle(), rnn, seqlen, xd, x, hd, h, yd, y, + workspace, length(workspace), dwd, dw, reserve, length(reserve)) +end + +function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T + dw = zero(rnn.params) + cudnnRNNBackwardWeights(rnn, 1, + xDesc(x), x, hDesc(h)..., xDesc(y), y, + FilterDesc(T, (1, 1, length(dw))), dw, + workspace[], reserve) + return params(dw, rnn.input, rnn.hidden, ngates(rnn)) +end + +# Interface + +import ..Flux: Flux, relu +import ..Tracker: TrackedArray +using .CuArrays.CUDAnative +using .CuArrays: @cuindex, cudims + +function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray) + function kernel(dst, src) + I = @cuindex dst + dst[I...] = src[reverse(I)...] + return + end + blk, thr = cudims(dst) + @cuda blocks=blk threads=thr kernel(dst, src) + return dst +end + +CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} +CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} +CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} +CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}} +CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} + +function copyparams!(m::CuRNNs, d::RNNDesc) + Wi, Wh = d.weights + copy_transpose!(Wi, Flux.data(m.Wi)) + copy_transpose!(Wh, Flux.data(m.Wh)) + copy_transpose!(d.bias, Flux.data(m.b)) + return +end + +function RNNDesc(m::CuRNNs{T}) where T + h, i = length(m.h), size(m.Wi, 2) + mode = m isa CuRNN ? + (m.σ == tanh ? RNN_TANH : RNN_RELU) : + m isa CuGRU ? 
GRU : LSTM + r = RNNDesc{T}(mode, i, h) + return r +end + +const descs = WeakKeyDict() + +function desc(rnn) + d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn)) + copyparams!(rnn, d) + return d +end + +import Flux.Tracker +import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies + +istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) + +function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(m, x, h, m.Wi, m.Wh, m.b) : + forward(desc(m), x, h) + return result[2], result[1] +end + +function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(m, x, h, m.Wi, m.Wh, m.b) : + forward(desc(m), x, h) + return result[2], result[1] +end + +function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) : + forward(desc(m), x, h[1], h[2]) + return (result[2], result[3]), result[1] +end + +(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) + +@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data(x), data(h)) + result, function (Δ) + y, ho = result + dy, dho = Δ + h_ = hBatch(x, data(h)) + dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)) + end +end + +@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b) + reserve, result = forwardTrain(desc(m), data.((x, h, c))...) + result, function (Δ) + y, ho = result + dy, dho, dco = Δ + h_ = hBatch(x, data(h)) + c_ = hBatch(x, data(c)) + dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) + (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve) + nobacksies(:RNN, + (dx, unbroadcast(h, dh), unbroadcast(c, dc), + transpose(dWi), transpose(dWh), db)) + end +end diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 164f6fa7..9201e991 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -44,7 +44,6 @@ end _testmode!(a::Dropout, test) = (a.active = !test) """ - LayerNorm(h::Integer) A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be @@ -86,7 +85,6 @@ See [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf). Example: - ```julia m = Chain( Dense(28^2, 64), @@ -101,14 +99,14 @@ mutable struct BatchNorm{F,V,W,N} β::V # bias γ::V # scale μ::W # moving mean - σ::W # moving std + σ²::W # moving std ϵ::N momentum::N active::Bool end BatchNorm(chs::Integer, λ = identity; - initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) = + initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) = BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), zeros(chs), ones(chs), ϵ, momentum, true) @@ -124,31 +122,31 @@ function (BN::BatchNorm)(x) if !BN.active μ = reshape(BN.μ, affine_shape...) - σ = reshape(BN.σ, affine_shape...) + σ² = reshape(BN.σ², affine_shape...) 
else T = eltype(x) ϵ = data(convert(T, BN.ϵ)) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) μ = mean(x, dims = axes) - σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ) + σ² = sum((x .- μ) .^ 2, dims = axes) ./ m # update moving mean/std mtm = data(convert(T, BN.momentum)) - BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,)) - BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1) + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) + BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) end let λ = BN.λ - λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...)) + λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...)) end end children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active) + (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active) mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active) + BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test) diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl index 1f54d1b9..e266a81b 100644 --- a/test/cuda/cuda.jl +++ b/test/cuda/cuda.jl @@ -36,4 +36,8 @@ Flux.back!(sum(l)) end -CuArrays.libcudnn != nothing && include("cudnn.jl") +if CuArrays.libcudnn != nothing + @info "Testing Flux/CUDNN" + include("cudnn.jl") + include("curnn.jl") +end diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl index d5cf442b..9a154961 100644 --- a/test/cuda/cudnn.jl +++ b/test/cuda/cudnn.jl @@ -1,48 +1,48 @@ -using Flux, CuArrays, Test +using Flux, Flux.Tracker, CuArrays, Test +using Flux.Tracker: TrackedArray, data -@info "Testing Flux/CUDNN" +@testset "CUDNN BatchNorm" begin + @testset "4D Input" begin + x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1)))) + m = BatchNorm(3) + cx = gpu(x) + cm = gpu(m) -@testset "RNN" begin - @testset for R in [RNN, GRU, LSTM] - rnn = R(10, 5) - curnn = mapleaves(gpu, rnn) - @testset for batch_size in (1, 5) - Flux.reset!(rnn) - Flux.reset!(curnn) - x = batch_size == 1 ? - param(rand(10)) : - param(rand(10,batch_size)) - cux = gpu(x) - y = (rnn(x); rnn(x)) - cuy = (curnn(cux); curnn(cux)) + y = m(x) + cy = cm(cx) - @test y.data ≈ collect(cuy.data) - @test haskey(Flux.CUDA.descs, curnn.cell) + @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}} - Δ = randn(size(y)) + @test cpu(data(cy)) ≈ data(y) - Flux.back!(y, Δ) - Flux.back!(cuy, gpu(Δ)) + g = rand(size(y)...) + Flux.back!(y, g) + Flux.back!(cy, gpu(g)) - @test x.grad ≈ collect(cux.grad) - @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) - @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) - @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) - @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) - if isdefined(rnn.cell, :c) - @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) - end - - Flux.reset!(rnn) - Flux.reset!(curnn) - ohx = batch_size == 1 ? 
- Flux.onehot(rand(1:10), 1:10) : - Flux.onehotbatch(rand(1:10, batch_size), 1:10) - cuohx = gpu(ohx) - y = (rnn(ohx); rnn(ohx)) - cuy = (curnn(cuohx); curnn(cuohx)) - - @test y.data ≈ collect(cuy.data) + @test m.γ.grad ≈ cpu(cm.γ.grad) + @test m.β.grad ≈ cpu(cm.β.grad) + @test x.grad ≈ cpu(x.grad) + end + + @testset "2D Input" begin + x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4)))) + m = BatchNorm(3) + cx = gpu(x) + cm = gpu(m) + + y = m(x) + cy = cm(cx) + + @test cy isa TrackedArray{Float32,2,CuArray{Float32,2}} + + @test cpu(data(cy)) ≈ data(y) + + g = rand(size(y)...) + Flux.back!(y, g) + Flux.back!(cy, gpu(g)) + + @test m.γ.grad ≈ cpu(cm.γ.grad) + @test m.β.grad ≈ cpu(cm.β.grad) + @test x.grad ≈ cpu(x.grad) end - end end diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl new file mode 100644 index 00000000..3f5e1819 --- /dev/null +++ b/test/cuda/curnn.jl @@ -0,0 +1,46 @@ +using Flux, CuArrays, Test + +@testset "RNN" begin + @testset for R in [RNN, GRU, LSTM] + rnn = R(10, 5) + curnn = mapleaves(gpu, rnn) + @testset for batch_size in (1, 5) + Flux.reset!(rnn) + Flux.reset!(curnn) + x = batch_size == 1 ? + param(rand(10)) : + param(rand(10,batch_size)) + cux = gpu(x) + y = (rnn(x); rnn(x)) + cuy = (curnn(cux); curnn(cux)) + + @test y.data ≈ collect(cuy.data) + @test haskey(Flux.CUDA.descs, curnn.cell) + + Δ = randn(size(y)) + + Flux.back!(y, Δ) + Flux.back!(cuy, gpu(Δ)) + + @test x.grad ≈ collect(cux.grad) + @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) + @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) + @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) + @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) + if isdefined(rnn.cell, :c) + @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) + end + + Flux.reset!(rnn) + Flux.reset!(curnn) + ohx = batch_size == 1 ? + Flux.onehot(rand(1:10), 1:10) : + Flux.onehotbatch(rand(1:10, batch_size), 1:10) + cuohx = gpu(ohx) + y = (rnn(ohx); rnn(ohx)) + cuy = (curnn(cuohx); curnn(cuohx)) + + @test y.data ≈ collect(cuy.data) + end + end +end diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index b17120b0..18276140 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -1,4 +1,5 @@ using Flux: testmode! +using Flux.Tracker: data @testset "Dropout" begin x = [1.,2.,3.] @@ -28,7 +29,8 @@ using Flux: testmode! end @testset "BatchNorm" begin - let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]') + let m = BatchNorm(2), x = param([1 3 5; + 2 4 6]) @test m.β.data == [0, 0] # initβ(2) @test m.γ.data == [1, 1] # initγ(2) @@ -53,29 +55,30 @@ end # .1 * 4 + 0 = .4 @test m.μ ≈ reshape([0.3, 0.4], 2, 1) - # julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] # 2×1 Array{Float64,2}: - # 1.14495 - # 1.14495 - @test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + # 1.3 + # 1.3 + @test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] 
testmode!(m) @test !m.active x′ = m(x).data - @test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179 + @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) end # with activation function - let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]') + let m = BatchNorm(2, sigmoid), x = param([1 3 5; + 2 4 6]) @test m.active m(x) testmode!(m) @test !m.active - x′ = m(x).data - @test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179) + y = m(x).data + @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7) end let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1)) @@ -85,7 +88,7 @@ end end let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1)) - y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) + y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :) y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4]) @test m(x) == y end diff --git a/test/runtests.jl b/test/runtests.jl index 1c7c461d..ef7ed208 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,7 @@ if Base.JLOptions().check_bounds == 1 exit() end -using Flux, Test, Random +using Flux, Test, Random, Statistics using Random Random.seed!(0)
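
Notes on the patch above: `handle()` moves up into `src/cuda/cuda.jl` so that both the new `src/cuda/curnn.jl` (the RNN bindings split out of `cudnn.jl`) and the new CuDNN batch-normalisation bindings can share it. On the layer side, `BatchNorm` now tracks the moving variance `σ²` instead of the moving standard deviation, and its default `ϵ` changes from `1e-8` to `1e-5`, matching `BATCHNORM_MIN_EPS` on the CuDNN side.

The sketch below shows how the new GPU batchnorm path is expected to be exercised. It closely follows the test added in `test/cuda/cudnn.jl` and assumes only a working CuArrays + CUDNN installation; nothing beyond that is implied by the patch itself:

```julia
using Flux, CuArrays
using Flux.Tracker: TrackedArray, data

x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))  # 4D input: CuDNN BatchNorm only supports 4D/5D tensors
m = BatchNorm(3)

cx = gpu(x)   # CuArray-backed TrackedArray, so it hits the new CuParam dispatch
cm = gpu(m)

y  = m(x)     # CPU path (src/layers/normalise.jl, using the new σ² field)
cy = cm(cx)   # GPU path (src/cuda/cudnn.jl → cudnnBatchNormalizationForwardTraining)

cpu(data(cy)) ≈ data(y)   # both paths should agree up to Float32 tolerance
```

For 2D activations the GPU method reshapes the input to `1 × 1 × features × batch` before calling CuDNN and `dropdims` the result back afterwards, so `Dense`-style inputs go through the same kernel as convolutional ones.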