From 7ac9e191cbd2d9fb235d48bd023178c70778f7e5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 28 Jun 2018 14:25:22 +0530 Subject: [PATCH] Revert 1 change --- src/cuda/cuda.jl | 1 - src/cuda/curnn.jl | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl index d0e14bf4..f2b05aca 100644 --- a/src/cuda/cuda.jl +++ b/src/cuda/cuda.jl @@ -3,7 +3,6 @@ module CUDA using CuArrays if CuArrays.cudnn_available() - CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} include("curnn.jl") include("cudnn.jl") end diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 94254f91..905b1ef4 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -234,6 +234,7 @@ function copy_transpose!(dst::CuArray, src::CuArray) return dst end +CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}