From 5ccde88ce61e777b125a3638c0621fd4a80c0031 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 28 Jun 2018 14:21:17 +0530 Subject: [PATCH] Minor fix for 5D support --- src/cuda/cuda.jl | 3 ++- src/cuda/cudnn.jl | 8 ++------ src/cuda/curnn.jl | 1 - 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index 764bb96f..d0e14bf4 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,8 +3,9 @@ module CUDA using CuArrays if CuArrays.cudnn_available() - include("cudnn.jl") + CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} include("curnn.jl") + include("cudnn.jl") end end diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 088876e4..100f9f4b 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -41,7 +41,7 @@ BNCache() = BNCache(nothing, nothing) # CuDNN supports only 4D and 5D Tensors for BatchNorm Operations # so use the native julia code when doing batchnorm on a 2D Array -function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 4}, +function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0), eps = T(1e-5), training = true) where T<:Union{Float32, Float64} @@ -179,11 +179,7 @@ end import ..Flux: Flux import ..Tracker: track, back, @back, istracked, TrackedArray -CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} -CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}} -CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T} - -(BN::BatchNorm)(x::CuParam45{T}) = +(BN::BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}) where T<:Union{Float32, Float64} = batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active) _batchnorm(g, b, x, running_mean, running_var, momentum, diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 905b1ef4..94254f91 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -234,7 +234,6 @@ function copy_transpose!(dst::CuArray, src::CuArray) return dst end -CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}