Minor fix for 5D support

This commit is contained in:
Avik Pal 2018-06-28 14:21:17 +05:30
parent 681d8c4dfc
commit 5ccde88ce6
3 changed files with 4 additions and 8 deletions

View File

@ -3,8 +3,9 @@ module CUDA
using CuArrays
if CuArrays.cudnn_available()
include("cudnn.jl")
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
include("curnn.jl")
include("cudnn.jl")
end
end

View File

@ -41,7 +41,7 @@ BNCache() = BNCache(nothing, nothing)
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
# so use the native julia code when doing batchnorm on a 2D Array
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 4},
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
@ -179,11 +179,7 @@ end
import ..Flux: Flux
import ..Tracker: track, back, @back, istracked, TrackedArray
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
(BN::BatchNorm)(x::CuParam45{T}) =
(BN::BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}) where T<:Union{Float32, Float64} =
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
_batchnorm(g, b, x, running_mean, running_var, momentum,

View File

@ -234,7 +234,6 @@ function copy_transpose!(dst::CuArray, src::CuArray)
return dst
end
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}