Merge pull request #294 from avik-pal/cudnn_batchnorm
Wrapper for CuDNN BatchNorm
This commit is contained in:
commit
dd154ca049
@ -7,6 +7,12 @@ if !applicable(CuArray{UInt8}, undef, 1)
|
|||||||
end
|
end
|
||||||
|
|
||||||
if CuArrays.libcudnn != nothing
|
if CuArrays.libcudnn != nothing
|
||||||
|
if isdefined(CuArrays, :libcudnn_handle)
|
||||||
|
handle() = CuArrays.libcudnn_handle[]
|
||||||
|
else
|
||||||
|
handle() = CuArrays.CUDNN.handle()
|
||||||
|
end
|
||||||
|
include("curnn.jl")
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
else
|
else
|
||||||
@warn("CUDNN is not installed, some functionality will not be available.")
|
@warn("CUDNN is not installed, some functionality will not be available.")
|
||||||
|
@ -1,13 +1,8 @@
|
|||||||
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t,
|
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||||
cudnnDataType, TensorDesc, FilterDesc
|
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
import ..Flux: data
|
||||||
using LinearAlgebra
|
using LinearAlgebra
|
||||||
|
|
||||||
if isdefined(CuArrays, :libcudnn_handle)
|
|
||||||
handle() = CuArrays.libcudnn_handle[]
|
|
||||||
else
|
|
||||||
handle() = CuArrays.CUDNN.handle()
|
|
||||||
end
|
|
||||||
|
|
||||||
mutable struct DropoutDesc
|
mutable struct DropoutDesc
|
||||||
ptr::Ptr{Nothing}
|
ptr::Ptr{Nothing}
|
||||||
states::CuVector{UInt8}
|
states::CuVector{UInt8}
|
||||||
@ -30,324 +25,204 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
|
|||||||
return desc
|
return desc
|
||||||
end
|
end
|
||||||
|
|
||||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
const BATCHNORM_SPATIAL = 1
|
||||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
const BATCHNORM_ACTIVATION = 0
|
||||||
const LSTM = 2 # LSTM with no peephole connections
|
const BATCHNORM_MIN_EPS = 1e-5
|
||||||
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
|
||||||
|
|
||||||
const LINEAR_INPUT = 0
|
@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
|
||||||
const SKIP_INPUT = 1
|
|
||||||
|
|
||||||
const UNIDIRECTIONAL = 0
|
@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
|
||||||
const BIDIRECTIONAL = 1
|
|
||||||
|
|
||||||
const RNN_ALGO_STANDARD = 0
|
mutable struct BNCache
|
||||||
const RNN_ALGO_PERSIST_STATIC = 1
|
mean
|
||||||
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
ivar
|
||||||
|
|
||||||
# param layout:
|
|
||||||
# RNN: [weight, bias] × [input, hidden]
|
|
||||||
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
|
||||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
|
||||||
|
|
||||||
function params(w::CuVector, input, hidden, n = 1)
|
|
||||||
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
|
|
||||||
wx = slice(0, (input, hidden*n))
|
|
||||||
wh = slice(length(wx), (hidden, hidden*n))
|
|
||||||
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
|
|
||||||
(wx, wh), bias
|
|
||||||
end
|
end
|
||||||
|
|
||||||
mutable struct RNNDesc{T}
|
BNCache() = BNCache(nothing, nothing)
|
||||||
mode::Int
|
|
||||||
input::Int
|
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||||
hidden::Int
|
# so reshape a 2D Tensor into 4D
|
||||||
params::CuVector{T}
|
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
|
||||||
weights::NTuple{2,CuMatrix{T}}
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
bias::CuVector{T}
|
cache = nothing, alpha = T(1), beta = T(0),
|
||||||
ptr::Ptr{Nothing}
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
|
||||||
|
dropdims(batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
|
||||||
|
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training), dims = (1, 2))
|
||||||
|
|
||||||
|
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
|
cache = nothing, alpha = T(1), beta = T(0),
|
||||||
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
|
y = similar(x)
|
||||||
|
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
|
||||||
|
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||||
|
y
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
|
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
function rnnParamSize(T, r, input)
|
momentum; cache = nothing,
|
||||||
size = Csize_t[0]
|
alpha = T(1), beta = T(0),
|
||||||
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
dims = _wsize(x)
|
||||||
return Int(size[])÷sizeof(T)
|
if eps < BATCHNORM_MIN_EPS
|
||||||
end
|
# warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||||
|
eps = BATCHNORM_MIN_EPS
|
||||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
|
||||||
ngates(r::RNNDesc) = ngates(r.mode)
|
|
||||||
|
|
||||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
|
||||||
d = [C_NULL]
|
|
||||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
|
|
||||||
|
|
||||||
dropoutDesc = DropoutDesc(0)
|
|
||||||
inputMode = LINEAR_INPUT
|
|
||||||
direction = UNIDIRECTIONAL
|
|
||||||
algo = RNN_ALGO_STANDARD
|
|
||||||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
|
|
||||||
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
|
||||||
|
|
||||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
|
||||||
# TODO: avoid reserve allocation here
|
|
||||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
|
||||||
finalizer(rd) do x
|
|
||||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
|
||||||
end
|
end
|
||||||
return rd
|
xd = TensorDesc(x)
|
||||||
end
|
yd = TensorDesc(y)
|
||||||
|
gd = TensorDesc(T, dims)
|
||||||
|
|
||||||
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
if training
|
||||||
size = Csize_t[0]
|
|
||||||
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
|
|
||||||
handle(), r, seqlen, xdesc, size)
|
|
||||||
return Int(size[])
|
|
||||||
end
|
|
||||||
|
|
||||||
const workspace = [CuVector{UInt8}(undef, 1)]
|
if cache !== nothing
|
||||||
|
mean = zeros(CuArray{T}, dims...)
|
||||||
|
ivar = ones(CuArray{T}, dims...)
|
||||||
|
else
|
||||||
|
mean = C_NULL
|
||||||
|
ivar = C_NULL
|
||||||
|
end
|
||||||
|
|
||||||
getworkspace(bytes) =
|
@check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
length(workspace[]) ≥ bytes ?
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
workspace[] :
|
Ptr{T}, Ptr{T},
|
||||||
(workspace[] = CuVector{UInt8}(undef, bytes))
|
|
||||||
|
|
||||||
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
|
||||||
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
|
||||||
|
|
||||||
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
|
||||||
size = Csize_t[0]
|
|
||||||
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
|
|
||||||
handle(), r, seqlen, xdesc, size)
|
|
||||||
return Int(size[])
|
|
||||||
end
|
|
||||||
|
|
||||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
|
||||||
workspace, reserve=nothing) where T
|
|
||||||
if reserve == nothing
|
|
||||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
|
||||||
Ptr{Nothing}, Ptr{T},
|
Ptr{Nothing}, Ptr{T},
|
||||||
Ptr{Nothing}, Csize_t),
|
Ptr{Nothing}, Ptr{T},
|
||||||
handle(), rnn, seqlen,
|
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
||||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
Cdouble, Ptr{T}, Ptr{T},
|
||||||
workspace, length(workspace))
|
Cdouble, Ptr{T}, Ptr{T}),
|
||||||
|
handle(), BATCHNORM_SPATIAL,
|
||||||
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
|
xd, x,
|
||||||
|
yd, y,
|
||||||
|
gd, g, b,
|
||||||
|
momentum, running_mean, running_var,
|
||||||
|
eps, mean, ivar)
|
||||||
|
|
||||||
|
if cache !== nothing
|
||||||
|
cache.mean = mean
|
||||||
|
cache.ivar = ivar
|
||||||
|
end
|
||||||
else
|
else
|
||||||
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
Ptr{Nothing}, Ptr{T},
|
||||||
handle(), rnn, seqlen,
|
Ptr{Nothing}, Ptr{T},
|
||||||
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
Ptr{Nothing}, Ptr{T}, Ptr{T},
|
||||||
workspace, length(workspace), reserve, length(reserve))
|
Ptr{T}, Ptr{T},
|
||||||
|
Cdouble),
|
||||||
|
handle(), BATCHNORM_SPATIAL,
|
||||||
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
|
xd, x,
|
||||||
|
yd, y,
|
||||||
|
gd, g, b,
|
||||||
|
running_mean, running_var,
|
||||||
|
eps)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
hDesc(h::Nothing) = C_NULL, C_NULL
|
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||||
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||||
function hDesc(h::CuArray)
|
dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
|
||||||
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
|
||||||
|
alpha = alpha, beta = beta, training = training)
|
||||||
|
(dg, db, dropdims(dx, dims = (1, 2)))
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: can we just manipulate strides here?
|
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||||
# TODO: should use repmat, but this isn't implemented.
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
hBatch(x::AbstractVector, h::CuVector) = h
|
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
beta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
dg = similar(g)
|
||||||
|
db = similar(b)
|
||||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
dx = similar(x)
|
||||||
h = hBatch(x, h_)
|
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
|
||||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
|
||||||
@assert size(x, 1) == rnn.input
|
(dg, db, dx)
|
||||||
@assert size(h, 1) == rnn.hidden
|
|
||||||
@assert size(x, 2) == size(h, 2)
|
|
||||||
seqLength = 1
|
|
||||||
xdesc = xDesc(x)
|
|
||||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
|
||||||
ho = similar(h)
|
|
||||||
ydesc = xDesc(y)
|
|
||||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
|
||||||
reserve = train == Val{true} ?
|
|
||||||
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
|
||||||
nothing
|
|
||||||
co = c == nothing ? c : similar(c)
|
|
||||||
cudnnRNNForward(rnn, seqLength,
|
|
||||||
xdesc, x,
|
|
||||||
hDesc(h)...,
|
|
||||||
hDesc(c)...,
|
|
||||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
|
||||||
ydesc, y,
|
|
||||||
hDesc(ho)...,
|
|
||||||
hDesc(co)...,
|
|
||||||
workspace, reserve)
|
|
||||||
result = c == nothing ? (y, ho) : (y, ho, co)
|
|
||||||
return train == Val{true} ? (reserve, result) : result
|
|
||||||
end
|
end
|
||||||
|
|
||||||
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||||
forward(rnn, x, h, c, Val{true})
|
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||||
|
momentum; cache = nothing, eps = T(1e-5),
|
||||||
|
alpha = T(1), beta = T(0),
|
||||||
|
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||||
|
if training
|
||||||
|
xd = TensorDesc(x)
|
||||||
|
dyd = TensorDesc(dy)
|
||||||
|
dxd = TensorDesc(dx)
|
||||||
|
gd = TensorDesc(T, _wsize(x))
|
||||||
|
if cache !== nothing
|
||||||
|
mean, ivar = cache.mean, cache.ivar
|
||||||
|
info("mean and ivar are fetched from the cache")
|
||||||
|
else
|
||||||
|
mean, ivar = C_NULL, C_NULL
|
||||||
|
end
|
||||||
|
|
||||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
if eps < BATCHNORM_MIN_EPS
|
||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
eps = BATCHNORM_MIN_EPS
|
||||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
end
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
|
||||||
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
|
|
||||||
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
|
||||||
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
|
||||||
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
|
||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
|
||||||
end
|
|
||||||
|
|
||||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
|
||||||
# Same as above, any more efficient way?
|
(cudnnHandle_t,cudnnBatchNormMode_t,
|
||||||
dy = dy_ isa Integer ? zero(y) : dy_
|
Ptr{T}, Ptr{T},
|
||||||
yd = xDesc(y)
|
Ptr{T}, Ptr{T},
|
||||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
Ptr{Nothing}, Ptr{T},
|
||||||
dh = similar(h)
|
Ptr{Nothing}, Ptr{T},
|
||||||
dc = c == nothing ? nothing : similar(c)
|
Ptr{Nothing}, Ptr{T},
|
||||||
cudnnRNNBackwardData(rnn, 1,
|
Ptr{Nothing}, Ptr{T}, Ptr{T}, Ptr{T},
|
||||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
Cdouble, Ptr{T}, Ptr{T}),
|
||||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
handle(), BATCHNORM_SPATIAL,
|
||||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
workspace[], reserve)
|
Ref(T(dalpha)), Ref(T(dbeta)),
|
||||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
xd, x,
|
||||||
end
|
dyd, dy,
|
||||||
|
dxd, dx,
|
||||||
backwardData(rnn, y, dy, dho, hx, reserve) =
|
gd, g, dg, db,
|
||||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
eps, mean, ivar)
|
||||||
|
else
|
||||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
|
||||||
workspace, reserve) where T
|
dx .= dy .* reshape(g, _wsize(x)) .* ivar
|
||||||
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), dims = (1,2,4))
|
||||||
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
db .= squeeze(sum(dy, _reddims(dy)), dims = (1,2,4))
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
|
||||||
Ptr{Nothing}, Ptr{T}, #hx
|
|
||||||
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
|
||||||
Ptr{Nothing}, Csize_t, #ws
|
|
||||||
Ptr{Nothing}, Ptr{T}, #dw
|
|
||||||
Ptr{Nothing}, Csize_t), #rs
|
|
||||||
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
|
||||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
|
||||||
end
|
|
||||||
|
|
||||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
|
||||||
dw = zero(rnn.params)
|
|
||||||
cudnnRNNBackwardWeights(rnn, 1,
|
|
||||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
|
||||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
|
||||||
workspace[], reserve)
|
|
||||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
|
||||||
end
|
|
||||||
|
|
||||||
# Interface
|
|
||||||
|
|
||||||
import ..Flux: Flux, relu
|
|
||||||
import ..Tracker: TrackedArray
|
|
||||||
using .CuArrays.CUDAnative
|
|
||||||
using .CuArrays: @cuindex, cudims
|
|
||||||
|
|
||||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
|
||||||
function kernel(dst, src)
|
|
||||||
I = @cuindex dst
|
|
||||||
dst[I...] = src[reverse(I)...]
|
|
||||||
return
|
|
||||||
end
|
|
||||||
blk, thr = cudims(dst)
|
|
||||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
|
||||||
return dst
|
|
||||||
end
|
|
||||||
|
|
||||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
|
||||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
|
||||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
|
||||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
|
||||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
|
||||||
|
|
||||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
|
||||||
Wi, Wh = d.weights
|
|
||||||
copy_transpose!(Wi, Flux.data(m.Wi))
|
|
||||||
copy_transpose!(Wh, Flux.data(m.Wh))
|
|
||||||
copy_transpose!(d.bias, Flux.data(m.b))
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function RNNDesc(m::CuRNNs{T}) where T
|
|
||||||
h, i = length(m.h), size(m.Wi, 2)
|
|
||||||
mode = m isa CuRNN ?
|
|
||||||
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
|
||||||
m isa CuGRU ? GRU : LSTM
|
|
||||||
r = RNNDesc{T}(mode, i, h)
|
|
||||||
return r
|
|
||||||
end
|
|
||||||
|
|
||||||
const descs = WeakKeyDict()
|
|
||||||
|
|
||||||
function desc(rnn)
|
|
||||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
|
||||||
copyparams!(rnn, d)
|
|
||||||
return d
|
|
||||||
end
|
|
||||||
|
|
||||||
import Flux.Tracker
|
|
||||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
|
||||||
|
|
||||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
|
||||||
|
|
||||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
|
||||||
result = istrain(m, h, x) ?
|
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
|
||||||
end
|
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
|
||||||
result = istrain(m, h, x) ?
|
|
||||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h)
|
|
||||||
return result[2], result[1]
|
|
||||||
end
|
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
|
||||||
result = istrain(m, h, x) ?
|
|
||||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
|
||||||
forward(desc(m), x, h[1], h[2])
|
|
||||||
return (result[2], result[3]), result[1]
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
|
||||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
|
||||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
|
||||||
|
|
||||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
|
||||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
|
||||||
result, function (Δ)
|
|
||||||
y, ho = result
|
|
||||||
dy, dho = Δ
|
|
||||||
h_ = hBatch(x, data(h))
|
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
|
||||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
# Flux Interface
|
||||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
|
||||||
result, function (Δ)
|
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||||
y, ho = result
|
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
||||||
dy, dho, dco = Δ
|
|
||||||
h_ = hBatch(x, data(h))
|
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||||
c_ = hBatch(x, data(c))
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
|
||||||
nobacksies(:RNN,
|
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
transpose(dWi), transpose(dWh), db))
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
end
|
|
||||||
end
|
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
|
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
|
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
|
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
|
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||||
|
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||||
|
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||||
|
|
||||||
|
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||||
|
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||||
|
325
src/cuda/curnn.jl
Normal file
325
src/cuda/curnn.jl
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
using .CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||||
|
cudnnBatchNormMode_t, cudnnHandle_t, cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
using LinearAlgebra
|
||||||
|
|
||||||
|
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||||
|
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||||
|
const LSTM = 2 # LSTM with no peephole connections
|
||||||
|
const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1)
|
||||||
|
|
||||||
|
const LINEAR_INPUT = 0
|
||||||
|
const SKIP_INPUT = 1
|
||||||
|
|
||||||
|
const UNIDIRECTIONAL = 0
|
||||||
|
const BIDIRECTIONAL = 1
|
||||||
|
|
||||||
|
const RNN_ALGO_STANDARD = 0
|
||||||
|
const RNN_ALGO_PERSIST_STATIC = 1
|
||||||
|
const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||||
|
|
||||||
|
# param layout:
|
||||||
|
# RNN: [weight, bias] × [input, hidden]
|
||||||
|
# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem]
|
||||||
|
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||||
|
|
||||||
|
function params(w::CuVector, input, hidden, n = 1)
|
||||||
|
slice(offset, shape) = reshape(view(w, offset.+(1:prod(shape))), shape)
|
||||||
|
wx = slice(0, (input, hidden*n))
|
||||||
|
wh = slice(length(wx), (hidden, hidden*n))
|
||||||
|
bias = view(w, length(wx)+length(wh) .+ (1:hidden*n))
|
||||||
|
(wx, wh), bias
|
||||||
|
end
|
||||||
|
|
||||||
|
mutable struct RNNDesc{T}
|
||||||
|
mode::Int
|
||||||
|
input::Int
|
||||||
|
hidden::Int
|
||||||
|
params::CuVector{T}
|
||||||
|
weights::NTuple{2,CuMatrix{T}}
|
||||||
|
bias::CuVector{T}
|
||||||
|
ptr::Ptr{Nothing}
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.unsafe_convert(::Type{Ptr{Nothing}}, d::RNNDesc) = d.ptr
|
||||||
|
|
||||||
|
function rnnParamSize(T, r, input)
|
||||||
|
size = Csize_t[0]
|
||||||
|
@check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Ptr{Nothing},Ptr{Csize_t},Cint),
|
||||||
|
handle(), r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T))
|
||||||
|
return Int(size[])÷sizeof(T)
|
||||||
|
end
|
||||||
|
|
||||||
|
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||||
|
ngates(r::RNNDesc) = ngates(r.mode)
|
||||||
|
|
||||||
|
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||||
|
d = [C_NULL]
|
||||||
|
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Nothing}},),d)
|
||||||
|
|
||||||
|
dropoutDesc = DropoutDesc(0)
|
||||||
|
inputMode = LINEAR_INPUT
|
||||||
|
direction = UNIDIRECTIONAL
|
||||||
|
algo = RNN_ALGO_STANDARD
|
||||||
|
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Cint,Ptr{Nothing},Cint,Cint,Cint,Cint,Cint),
|
||||||
|
handle(),d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||||
|
|
||||||
|
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||||
|
# TODO: avoid reserve allocation here
|
||||||
|
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||||
|
finalizer(rd) do x
|
||||||
|
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||||
|
end
|
||||||
|
return rd
|
||||||
|
end
|
||||||
|
|
||||||
|
function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
|
||||||
|
size = Csize_t[0]
|
||||||
|
@check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Nothing},Ptr{Nothing},Cint,Ptr{Ptr{Nothing}},Ptr{Csize_t}),
|
||||||
|
handle(), r, seqlen, xdesc, size)
|
||||||
|
return Int(size[])
|
||||||
|
end
|
||||||
|
|
||||||
|
const workspace = [CuVector{UInt8}(undef, 1)]
|
||||||
|
|
||||||
|
getworkspace(bytes) =
|
||||||
|
length(workspace[]) ≥ bytes ?
|
||||||
|
workspace[] :
|
||||||
|
(workspace[] = CuVector{UInt8}(undef, bytes))
|
||||||
|
|
||||||
|
getworkspace(r::RNNDesc, seqlen, xdesc) =
|
||||||
|
getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
|
||||||
|
|
||||||
|
function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
||||||
|
size = Csize_t[0]
|
||||||
|
@check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Nothing}, Ptr{Nothing}, Cint, Ptr{Ptr{Nothing}}, Ptr{Csize_t}),
|
||||||
|
handle(), r, seqlen, xdesc, size)
|
||||||
|
return Int(size[])
|
||||||
|
end
|
||||||
|
|
||||||
|
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
|
workspace, reserve=nothing) where T
|
||||||
|
if reserve == nothing
|
||||||
|
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||||
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
|
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Csize_t),
|
||||||
|
handle(), rnn, seqlen,
|
||||||
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
|
workspace, length(workspace))
|
||||||
|
else
|
||||||
|
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
|
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||||
|
handle(), rnn, seqlen,
|
||||||
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
|
workspace, length(workspace), reserve, length(reserve))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
|
||||||
|
|
||||||
|
hDesc(h::Nothing) = C_NULL, C_NULL
|
||||||
|
hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
|
||||||
|
function hDesc(h::CuArray)
|
||||||
|
TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
|
||||||
|
end
|
||||||
|
|
||||||
|
# TODO: can we just manipulate strides here?
|
||||||
|
# TODO: should use repmat, but this isn't implemented.
|
||||||
|
hBatch(x::AbstractVector, h::CuVector) = h
|
||||||
|
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||||
|
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||||
|
|
||||||
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||||
|
h = hBatch(x, h_)
|
||||||
|
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||||
|
@assert size(x, 1) == rnn.input
|
||||||
|
@assert size(h, 1) == rnn.hidden
|
||||||
|
@assert size(x, 2) == size(h, 2)
|
||||||
|
seqLength = 1
|
||||||
|
xdesc = xDesc(x)
|
||||||
|
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||||
|
ho = similar(h)
|
||||||
|
ydesc = xDesc(y)
|
||||||
|
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||||
|
reserve = train == Val{true} ?
|
||||||
|
CuVector{UInt8}(undef, rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||||
|
nothing
|
||||||
|
co = c == nothing ? c : similar(c)
|
||||||
|
cudnnRNNForward(rnn, seqLength,
|
||||||
|
xdesc, x,
|
||||||
|
hDesc(h)...,
|
||||||
|
hDesc(c)...,
|
||||||
|
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||||
|
ydesc, y,
|
||||||
|
hDesc(ho)...,
|
||||||
|
hDesc(co)...,
|
||||||
|
workspace, reserve)
|
||||||
|
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||||
|
return train == Val{true} ? (reserve, result) : result
|
||||||
|
end
|
||||||
|
|
||||||
|
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||||
|
forward(rnn, x, h, c, Val{true})
|
||||||
|
|
||||||
|
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||||
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||||
|
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||||
|
(Ptr{Nothing}, Ptr{Nothing}, Cint,
|
||||||
|
Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing},
|
||||||
|
Ptr{T}, Ptr{Ptr{Nothing}}, Ptr{T}, Ptr{Nothing}, Ptr{T}, Ptr{Nothing}, Ptr{T},
|
||||||
|
Ptr{Nothing}, Csize_t, Ptr{Nothing}, Csize_t),
|
||||||
|
handle(), rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||||
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||||
|
end
|
||||||
|
|
||||||
|
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||||
|
# Same as above, any more efficient way?
|
||||||
|
dy = dy_ isa Integer ? zero(y) : dy_
|
||||||
|
yd = xDesc(y)
|
||||||
|
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||||
|
dh = similar(h)
|
||||||
|
dc = c == nothing ? nothing : similar(c)
|
||||||
|
cudnnRNNBackwardData(rnn, 1,
|
||||||
|
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||||
|
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||||
|
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||||
|
workspace[], reserve)
|
||||||
|
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||||
|
end
|
||||||
|
|
||||||
|
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||||
|
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||||
|
|
||||||
|
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||||
|
workspace, reserve) where T
|
||||||
|
@check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
|
||||||
|
(Ptr{Nothing}, Ptr{Nothing}, Cint, # handle, rnnDesc, seqLength
|
||||||
|
Ptr{Ptr{Nothing}}, Ptr{T}, #x
|
||||||
|
Ptr{Nothing}, Ptr{T}, #hx
|
||||||
|
Ptr{Ptr{Nothing}}, Ptr{T}, #y
|
||||||
|
Ptr{Nothing}, Csize_t, #ws
|
||||||
|
Ptr{Nothing}, Ptr{T}, #dw
|
||||||
|
Ptr{Nothing}, Csize_t), #rs
|
||||||
|
handle(), rnn, seqlen, xd, x, hd, h, yd, y,
|
||||||
|
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||||
|
end
|
||||||
|
|
||||||
|
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||||
|
dw = zero(rnn.params)
|
||||||
|
cudnnRNNBackwardWeights(rnn, 1,
|
||||||
|
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||||
|
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||||
|
workspace[], reserve)
|
||||||
|
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||||
|
end
|
||||||
|
|
||||||
|
# Interface
|
||||||
|
|
||||||
|
import ..Flux: Flux, relu
|
||||||
|
import ..Tracker: TrackedArray
|
||||||
|
using .CuArrays.CUDAnative
|
||||||
|
using .CuArrays: @cuindex, cudims
|
||||||
|
|
||||||
|
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||||
|
function kernel(dst, src)
|
||||||
|
I = @cuindex dst
|
||||||
|
dst[I...] = src[reverse(I)...]
|
||||||
|
return
|
||||||
|
end
|
||||||
|
blk, thr = cudims(dst)
|
||||||
|
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||||
|
return dst
|
||||||
|
end
|
||||||
|
|
||||||
|
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||||
|
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
|
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
|
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
|
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||||
|
|
||||||
|
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||||
|
Wi, Wh = d.weights
|
||||||
|
copy_transpose!(Wi, Flux.data(m.Wi))
|
||||||
|
copy_transpose!(Wh, Flux.data(m.Wh))
|
||||||
|
copy_transpose!(d.bias, Flux.data(m.b))
|
||||||
|
return
|
||||||
|
end
|
||||||
|
|
||||||
|
function RNNDesc(m::CuRNNs{T}) where T
|
||||||
|
h, i = length(m.h), size(m.Wi, 2)
|
||||||
|
mode = m isa CuRNN ?
|
||||||
|
(m.σ == tanh ? RNN_TANH : RNN_RELU) :
|
||||||
|
m isa CuGRU ? GRU : LSTM
|
||||||
|
r = RNNDesc{T}(mode, i, h)
|
||||||
|
return r
|
||||||
|
end
|
||||||
|
|
||||||
|
const descs = WeakKeyDict()
|
||||||
|
|
||||||
|
function desc(rnn)
|
||||||
|
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn))
|
||||||
|
copyparams!(rnn, d)
|
||||||
|
return d
|
||||||
|
end
|
||||||
|
|
||||||
|
import Flux.Tracker
|
||||||
|
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
||||||
|
|
||||||
|
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||||
|
|
||||||
|
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
|
result = istrain(m, h, x) ?
|
||||||
|
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||||
|
forward(desc(m), x, h)
|
||||||
|
return result[2], result[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
|
result = istrain(m, h, x) ?
|
||||||
|
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||||
|
forward(desc(m), x, h)
|
||||||
|
return result[2], result[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
|
result = istrain(m, h, x) ?
|
||||||
|
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
||||||
|
forward(desc(m), x, h[1], h[2])
|
||||||
|
return (result[2], result[3]), result[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
|
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||||
|
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||||
|
result, function (Δ)
|
||||||
|
y, ho = result
|
||||||
|
dy, dho = Δ
|
||||||
|
h_ = hBatch(x, data(h))
|
||||||
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
|
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||||
|
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||||
|
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||||
|
result, function (Δ)
|
||||||
|
y, ho = result
|
||||||
|
dy, dho, dco = Δ
|
||||||
|
h_ = hBatch(x, data(h))
|
||||||
|
c_ = hBatch(x, data(c))
|
||||||
|
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||||
|
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||||
|
nobacksies(:RNN,
|
||||||
|
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||||
|
transpose(dWi), transpose(dWh), db))
|
||||||
|
end
|
||||||
|
end
|
@ -44,7 +44,6 @@ end
|
|||||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LayerNorm(h::Integer)
|
LayerNorm(h::Integer)
|
||||||
|
|
||||||
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
|
||||||
@ -86,7 +85,6 @@ See [Batch Normalization: Accelerating Deep Network Training by Reducing
|
|||||||
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```julia
|
```julia
|
||||||
m = Chain(
|
m = Chain(
|
||||||
Dense(28^2, 64),
|
Dense(28^2, 64),
|
||||||
@ -101,14 +99,14 @@ mutable struct BatchNorm{F,V,W,N}
|
|||||||
β::V # bias
|
β::V # bias
|
||||||
γ::V # scale
|
γ::V # scale
|
||||||
μ::W # moving mean
|
μ::W # moving mean
|
||||||
σ::W # moving std
|
σ²::W # moving std
|
||||||
ϵ::N
|
ϵ::N
|
||||||
momentum::N
|
momentum::N
|
||||||
active::Bool
|
active::Bool
|
||||||
end
|
end
|
||||||
|
|
||||||
BatchNorm(chs::Integer, λ = identity;
|
BatchNorm(chs::Integer, λ = identity;
|
||||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
|
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
|
||||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||||
|
|
||||||
@ -124,31 +122,31 @@ function (BN::BatchNorm)(x)
|
|||||||
|
|
||||||
if !BN.active
|
if !BN.active
|
||||||
μ = reshape(BN.μ, affine_shape...)
|
μ = reshape(BN.μ, affine_shape...)
|
||||||
σ = reshape(BN.σ, affine_shape...)
|
σ² = reshape(BN.σ², affine_shape...)
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
ϵ = data(convert(T, BN.ϵ))
|
ϵ = data(convert(T, BN.ϵ))
|
||||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ = sqrt.(mean((x .- μ).^2, dims = axes) .+ ϵ)
|
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||||
|
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, BN.momentum))
|
mtm = data(convert(T, BN.momentum))
|
||||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
|
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* dropdims(data(σ), dims = (axes...,)) .* m ./ (m - 1)
|
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = BN.λ
|
let λ = BN.λ
|
||||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
|
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(BN::BatchNorm) =
|
children(BN::BatchNorm) =
|
||||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
|
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
|
||||||
|
|
||||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
|
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
|
||||||
|
|
||||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
||||||
|
|
||||||
|
@ -36,4 +36,8 @@ Flux.back!(sum(l))
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
CuArrays.libcudnn != nothing && include("cudnn.jl")
|
if CuArrays.libcudnn != nothing
|
||||||
|
@info "Testing Flux/CUDNN"
|
||||||
|
include("cudnn.jl")
|
||||||
|
include("curnn.jl")
|
||||||
|
end
|
||||||
|
@ -1,48 +1,48 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, Flux.Tracker, CuArrays, Test
|
||||||
|
using Flux.Tracker: TrackedArray, data
|
||||||
|
|
||||||
@info "Testing Flux/CUDNN"
|
@testset "CUDNN BatchNorm" begin
|
||||||
|
@testset "4D Input" begin
|
||||||
|
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
|
||||||
|
m = BatchNorm(3)
|
||||||
|
cx = gpu(x)
|
||||||
|
cm = gpu(m)
|
||||||
|
|
||||||
@testset "RNN" begin
|
y = m(x)
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
cy = cm(cx)
|
||||||
rnn = R(10, 5)
|
|
||||||
curnn = mapleaves(gpu, rnn)
|
|
||||||
@testset for batch_size in (1, 5)
|
|
||||||
Flux.reset!(rnn)
|
|
||||||
Flux.reset!(curnn)
|
|
||||||
x = batch_size == 1 ?
|
|
||||||
param(rand(10)) :
|
|
||||||
param(rand(10,batch_size))
|
|
||||||
cux = gpu(x)
|
|
||||||
y = (rnn(x); rnn(x))
|
|
||||||
cuy = (curnn(cux); curnn(cux))
|
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
|
||||||
|
|
||||||
Δ = randn(size(y))
|
@test cpu(data(cy)) ≈ data(y)
|
||||||
|
|
||||||
Flux.back!(y, Δ)
|
g = rand(size(y)...)
|
||||||
Flux.back!(cuy, gpu(Δ))
|
Flux.back!(y, g)
|
||||||
|
Flux.back!(cy, gpu(g))
|
||||||
|
|
||||||
@test x.grad ≈ collect(cux.grad)
|
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
@test x.grad ≈ cpu(x.grad)
|
||||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
end
|
||||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
|
||||||
if isdefined(rnn.cell, :c)
|
@testset "2D Input" begin
|
||||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
|
||||||
end
|
m = BatchNorm(3)
|
||||||
|
cx = gpu(x)
|
||||||
Flux.reset!(rnn)
|
cm = gpu(m)
|
||||||
Flux.reset!(curnn)
|
|
||||||
ohx = batch_size == 1 ?
|
y = m(x)
|
||||||
Flux.onehot(rand(1:10), 1:10) :
|
cy = cm(cx)
|
||||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
|
||||||
cuohx = gpu(ohx)
|
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||||
y = (rnn(ohx); rnn(ohx))
|
|
||||||
cuy = (curnn(cuohx); curnn(cuohx))
|
@test cpu(data(cy)) ≈ data(y)
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
g = rand(size(y)...)
|
||||||
|
Flux.back!(y, g)
|
||||||
|
Flux.back!(cy, gpu(g))
|
||||||
|
|
||||||
|
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||||
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
|
@test x.grad ≈ cpu(x.grad)
|
||||||
end
|
end
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
46
test/cuda/curnn.jl
Normal file
46
test/cuda/curnn.jl
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
using Flux, CuArrays, Test
|
||||||
|
|
||||||
|
@testset "RNN" begin
|
||||||
|
@testset for R in [RNN, GRU, LSTM]
|
||||||
|
rnn = R(10, 5)
|
||||||
|
curnn = mapleaves(gpu, rnn)
|
||||||
|
@testset for batch_size in (1, 5)
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
x = batch_size == 1 ?
|
||||||
|
param(rand(10)) :
|
||||||
|
param(rand(10,batch_size))
|
||||||
|
cux = gpu(x)
|
||||||
|
y = (rnn(x); rnn(x))
|
||||||
|
cuy = (curnn(cux); curnn(cux))
|
||||||
|
|
||||||
|
@test y.data ≈ collect(cuy.data)
|
||||||
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
|
||||||
|
Δ = randn(size(y))
|
||||||
|
|
||||||
|
Flux.back!(y, Δ)
|
||||||
|
Flux.back!(cuy, gpu(Δ))
|
||||||
|
|
||||||
|
@test x.grad ≈ collect(cux.grad)
|
||||||
|
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||||
|
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
||||||
|
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
||||||
|
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
||||||
|
if isdefined(rnn.cell, :c)
|
||||||
|
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||||
|
end
|
||||||
|
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
ohx = batch_size == 1 ?
|
||||||
|
Flux.onehot(rand(1:10), 1:10) :
|
||||||
|
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||||
|
cuohx = gpu(ohx)
|
||||||
|
y = (rnn(ohx); rnn(ohx))
|
||||||
|
cuy = (curnn(cuohx); curnn(cuohx))
|
||||||
|
|
||||||
|
@test y.data ≈ collect(cuy.data)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -1,4 +1,5 @@
|
|||||||
using Flux: testmode!
|
using Flux: testmode!
|
||||||
|
using Flux.Tracker: data
|
||||||
|
|
||||||
@testset "Dropout" begin
|
@testset "Dropout" begin
|
||||||
x = [1.,2.,3.]
|
x = [1.,2.,3.]
|
||||||
@ -28,7 +29,8 @@ using Flux: testmode!
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "BatchNorm" begin
|
@testset "BatchNorm" begin
|
||||||
let m = BatchNorm(2), x = param([1 2; 3 4; 5 6]')
|
let m = BatchNorm(2), x = param([1 3 5;
|
||||||
|
2 4 6])
|
||||||
|
|
||||||
@test m.β.data == [0, 0] # initβ(2)
|
@test m.β.data == [0, 0] # initβ(2)
|
||||||
@test m.γ.data == [1, 1] # initγ(2)
|
@test m.γ.data == [1, 1] # initγ(2)
|
||||||
@ -53,29 +55,30 @@ end
|
|||||||
# .1 * 4 + 0 = .4
|
# .1 * 4 + 0 = .4
|
||||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||||
|
|
||||||
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
# julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
# 2×1 Array{Float64,2}:
|
# 2×1 Array{Float64,2}:
|
||||||
# 1.14495
|
# 1.3
|
||||||
# 1.14495
|
# 1.3
|
||||||
@test m.σ ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||||
|
|
||||||
testmode!(m)
|
testmode!(m)
|
||||||
@test !m.active
|
@test !m.active
|
||||||
|
|
||||||
x′ = m(x).data
|
x′ = m(x).data
|
||||||
@test x′[1] ≈ (1 .- 0.3) / 1.1449489742783179
|
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
|
||||||
end
|
end
|
||||||
|
|
||||||
# with activation function
|
# with activation function
|
||||||
let m = BatchNorm(2, σ), x = param([1 2; 3 4; 5 6]')
|
let m = BatchNorm(2, sigmoid), x = param([1 3 5;
|
||||||
|
2 4 6])
|
||||||
@test m.active
|
@test m.active
|
||||||
m(x)
|
m(x)
|
||||||
|
|
||||||
testmode!(m)
|
testmode!(m)
|
||||||
@test !m.active
|
@test !m.active
|
||||||
|
|
||||||
x′ = m(x).data
|
y = m(x).data
|
||||||
@test x′[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
|
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
||||||
@ -85,7 +88,7 @@ end
|
|||||||
end
|
end
|
||||||
|
|
||||||
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
||||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||||
@test m(x) == y
|
@test m(x) == y
|
||||||
end
|
end
|
||||||
|
@ -13,7 +13,7 @@ if Base.JLOptions().check_bounds == 1
|
|||||||
exit()
|
exit()
|
||||||
end
|
end
|
||||||
|
|
||||||
using Flux, Test, Random
|
using Flux, Test, Random, Statistics
|
||||||
using Random
|
using Random
|
||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user