Minor fix for 5D support
This commit is contained in:
parent
681d8c4dfc
commit
5ccde88ce6
@ -3,8 +3,9 @@ module CUDA
|
|||||||
using CuArrays
|
using CuArrays
|
||||||
|
|
||||||
if CuArrays.cudnn_available()
|
if CuArrays.cudnn_available()
|
||||||
include("cudnn.jl")
|
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||||
include("curnn.jl")
|
include("curnn.jl")
|
||||||
|
include("cudnn.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -41,7 +41,7 @@ BNCache() = BNCache(nothing, nothing)
|
|||||||
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||||
# so use the native julia code when doing batchnorm on a 2D Array
|
# so use the native julia code when doing batchnorm on a 2D Array
|
||||||
|
|
||||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 4},
|
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
cache = nothing, alpha = T(1), beta = T(0),
|
cache = nothing, alpha = T(1), beta = T(0),
|
||||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
@ -179,11 +179,7 @@ end
|
|||||||
import ..Flux: Flux
|
import ..Flux: Flux
|
||||||
import ..Tracker: track, back, @back, istracked, TrackedArray
|
import ..Tracker: track, back, @back, istracked, TrackedArray
|
||||||
|
|
||||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
(BN::BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}) where T<:Union{Float32, Float64} =
|
||||||
CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
|
|
||||||
CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
|
|
||||||
|
|
||||||
(BN::BatchNorm)(x::CuParam45{T}) =
|
|
||||||
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
||||||
|
|
||||||
_batchnorm(g, b, x, running_mean, running_var, momentum,
|
_batchnorm(g, b, x, running_mean, running_var, momentum,
|
||||||
|
@ -234,7 +234,6 @@ function copy_transpose!(dst::CuArray, src::CuArray)
|
|||||||
return dst
|
return dst
|
||||||
end
|
end
|
||||||
|
|
||||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
|
||||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||||
|
Loading…
Reference in New Issue
Block a user