diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index d0e14bf4..f2b05aca 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,7 +3,6 @@ module CUDA using CuArrays if CuArrays.cudnn_available() - CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} include("curnn.jl") include("cudnn.jl") end diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 94254f91..905b1ef4 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -234,6 +234,7 @@ function copy_transpose!(dst::CuArray, src::CuArray) return dst end +CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}