Minor fix for 5D support
This commit is contained in:
parent
681d8c4dfc
commit
5ccde88ce6
@ -3,8 +3,9 @@ module CUDA
|
||||
using CuArrays
|
||||
|
||||
if CuArrays.cudnn_available()
|
||||
include("cudnn.jl")
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
include("curnn.jl")
|
||||
include("cudnn.jl")
|
||||
end
|
||||
|
||||
end
|
||||
|
@ -41,7 +41,7 @@ BNCache() = BNCache(nothing, nothing)
|
||||
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||
# so use the native julia code when doing batchnorm on a 2D Array
|
||||
|
||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 4},
|
||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
@ -179,11 +179,7 @@ end
|
||||
import ..Flux: Flux
|
||||
import ..Tracker: track, back, @back, istracked, TrackedArray
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
|
||||
CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
|
||||
|
||||
(BN::BatchNorm)(x::CuParam45{T}) =
|
||||
(BN::BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}) where T<:Union{Float32, Float64} =
|
||||
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
||||
|
||||
_batchnorm(g, b, x, running_mean, running_var, momentum,
|
||||
|
@ -234,7 +234,6 @@ function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
|
Loading…
Reference in New Issue
Block a user