diff --git a/src/Flux.jl b/src/Flux.jl index 87a37566..30baf2bd 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -36,4 +36,6 @@ include("layers/normalisation.jl") include("data/Data.jl") +@require CuArrays include("cuda/cuda.jl") + end # module diff --git a/src/cuda/cuda.jl b/src/cuda/cuda.jl new file mode 100644 index 00000000..eaa3fe00 --- /dev/null +++ b/src/cuda/cuda.jl @@ -0,0 +1,7 @@ +module CUDA + +using CuArrays + +CuArrays.cudnn_available() && include("cudnn.jl") + +end diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl new file mode 100644 index 00000000..decc91ae --- /dev/null +++ b/src/cuda/cudnn.jl @@ -0,0 +1,368 @@ +using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, + cudnnDataType, TensorDesc, FilterDesc + +mutable struct DropoutDesc + ptr::Ptr{Void} + states::CuVector{UInt8} +end + +Base.unsafe_convert(::Type{Ptr{Void}}, dd::DropoutDesc) = dd.ptr + +function DropoutDesc(ρ::Real; seed::Integer=0) + d = [C_NULL] + s = Csize_t[0] + @check ccall((:cudnnCreateDropoutDescriptor,libcudnn), cudnnStatus_t, (Ptr{Ptr{Void}},), d) + @check ccall((:cudnnDropoutGetStatesSize,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Csize_t}),libcudnn_handle[],s) + states = CuArray{UInt8}(s[]) # TODO: can we drop this when ρ=0? + desc = DropoutDesc(d[], states) + @check ccall((:cudnnSetDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},Ptr{Void},Cfloat,Ptr{Void},Csize_t,Culonglong), + desc,libcudnn_handle[],ρ,states,length(states),seed) + finalizer(desc, x -> + @check ccall((:cudnnDestroyDropoutDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) + return desc +end + +const RNN_RELU = 0 # Stock RNN with ReLu activation +const RNN_TANH = 1 # Stock RNN with tanh activation +const LSTM = 2 # LSTM with no peephole connections +const GRU = 3 # Using h' = tanh(r * Uh(t-1) + Wx) and h = (1 - z) * h' + z * h(t-1) + +const LINEAR_INPUT = 0 +const SKIP_INPUT = 1 + +const UNIDIRECTIONAL = 0 +const BIDIRECTIONAL = 1 + +const RNN_ALGO_STANDARD = 0 +const RNN_ALGO_PERSIST_STATIC = 1 +const RNN_ALGO_PERSIST_DYNAMIC = 2 + +# param layout: +# RNN: [weight, bias] × [input, hidden] +# GRU: [weight, bias] × [input, hidden] × [reset, update, newmem] +# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output] + +function params(w::CuVector, input, hidden, n = 1) + slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape) + wx = slice(0, (input, hidden*n)) + wh = slice(length(wx), (hidden, hidden*n)) + bias = w[length(wx)+length(wh) + (1:hidden*n)] + (wx, wh), bias +end + +mutable struct RNNDesc{T} + mode::Int + input::Int + hidden::Int + params::CuVector{T} + weights::NTuple{2,CuMatrix{T}} + bias::CuVector{T} + ptr::Ptr{Void} +end + +Base.unsafe_convert(::Type{Ptr{Void}}, d::RNNDesc) = d.ptr + +function rnnParamSize(T, r, input) + size = Csize_t[0] + @check ccall((:cudnnGetRNNParamsSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Csize_t},Cint), + libcudnn_handle[], r, TensorDesc(T, (1,input,1)), size, cudnnDataType(T)) + return Int(size[])÷sizeof(T) +end + +ngates(mode) = [1, 1, 4, 3][mode+1] +ngates(r::RNNDesc) = ngates(r.mode) + +function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T + d = [C_NULL] + @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d) + + dropoutDesc = DropoutDesc(0) + inputMode = LINEAR_INPUT + direction = UNIDIRECTIONAL + algo = RNN_ALGO_STANDARD + @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint), + libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T)) + + w = cuzeros(T, rnnParamSize(T, d[], input)) + # TODO: avoid reserve allocation here + rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[]) + finalizer(rd, x -> + @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) + return rd +end + +function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc) + size = Csize_t[0] + @check ccall((:cudnnGetRNNWorkspaceSize, libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Ptr{Ptr{Void}},Ptr{Csize_t}), + libcudnn_handle[], r, seqlen, xdesc, size) + return Int(size[]) +end + +const workspace = [CuVector{UInt8}(1)] + +getworkspace(bytes) = + length(workspace[]) ≥ bytes ? + workspace[] : + (workspace[] = CuVector{UInt8}(bytes)) + +getworkspace(r::RNNDesc, seqlen, xdesc) = + getworkspace(rnnWorkspaceSize(r, seqlen, xdesc)) + +function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) + size = Csize_t[0] + @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}), + libcudnn_handle[], r, seqlen, xdesc, size) + return Int(size[]) +end + +function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, reserve=nothing) where T + if reserve == nothing + @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace)) + else + @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t, Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace), reserve, length(reserve)) + end +end + +xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))] + +hDesc(h::Void) = C_NULL, C_NULL +hDesc(x::Integer) = (@assert x == 0; hDesc(nothing)) +function hDesc(h::CuArray) + TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h +end + +# TODO: can we just manipulate strides here? +# TODO: should use repmat, but this isn't implemented. +hBatch(x::AbstractVector, h::CuVector) = h +hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2)) +hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1) + +function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T + h = hBatch(x, h_) + c = c_ == nothing ? nothing : hBatch(x, c_) + @assert size(x, 1) == rnn.input + @assert size(h, 1) == rnn.hidden + @assert size(x, 2) == size(h, 2) + seqLength = 1 + xdesc = xDesc(x) + y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) + ho = similar(h) + ydesc = xDesc(y) + workspace = getworkspace(rnn, seqLength, xdesc) + reserve = train == Val{true} ? + CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) : + nothing + co = c == nothing ? c : similar(c) + cudnnRNNForward(rnn, seqLength, + xdesc, x, + hDesc(h)..., + hDesc(c)..., + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, + ydesc, y, + hDesc(ho)..., + hDesc(co)..., + workspace, reserve) + result = c == nothing ? (y, ho) : (y, ho, co) + return train == Val{true} ? (reserve, result) : result +end + +forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T = + forward(rnn, x, h, c, Val{true}) + +function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T + @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, + Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t, Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco, + wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs)) +end + +function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T + # Same as above, any more efficient way? + dy = dy_ isa Integer ? zeros(y) : dy_ + yd = xDesc(y) + dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2)) + dh = similar(h) + dc = c == nothing ? nothing : similar(c) + cudnnRNNBackwardData(rnn, 1, + yd, y, yd, dy, hDesc(dho)..., hDesc(dco)..., + FilterDesc(T, (1, 1, length(rnn.params))), rnn.params, + hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)..., + workspace[], reserve) + return c == nothing ? (dx, dh) : (dx, dh, dc) +end + +backwardData(rnn, y, dy, dho, hx, reserve) = + backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve) + +function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw, + workspace, reserve) where T + @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength + Ptr{Ptr{Void}}, Ptr{T}, #x + Ptr{Void}, Ptr{T}, #hx + Ptr{Ptr{Void}}, Ptr{T}, #y + Ptr{Void}, Csize_t, #ws + Ptr{Void}, Ptr{T}, #dw + Ptr{Void}, Csize_t), #rs + libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y, + workspace, length(workspace), dwd, dw, reserve, length(reserve)) +end + +function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T + dw = zeros(rnn.params) + cudnnRNNBackwardWeights(rnn, 1, + xDesc(x), x, hDesc(h)..., xDesc(y), y, + FilterDesc(T, (1, 1, length(dw))), dw, + workspace[], reserve) + return params(dw, rnn.input, rnn.hidden, ngates(rnn)) +end + +# Interface + +import ..Flux: Flux, relu +import ..Flux.Tracker: TrackedArray +using CUDAnative +using CuArrays: @cuindex, cudims + +function copy_transpose!(dst::CuArray, src::CuArray) + function kernel(dst, src) + I = @cuindex dst + dst[I...] = src[reverse(I)...] + return + end + blk, thr = cudims(dst) + @cuda (blk, thr) kernel(dst, src) + return dst +end + +CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}} +CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}} +CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}} +CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}} +CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} + +function copyparams!(m::CuRNNs, d::RNNDesc) + Wi, Wh = d.weights + copy_transpose!(Wi, Flux.data(m.Wi)) + copy_transpose!(Wh, Flux.data(m.Wh)) + copy_transpose!(d.bias, Flux.data(m.b)) + return +end + +function RNNDesc(m::CuRNNs{T}) where T + h, i = length(m.h), size(m.Wi, 2) + mode = m isa CuRNN ? + (m.σ == tanh ? RNN_TANH : RNN_RELU) : + m isa CuGRU ? GRU : LSTM + r = RNNDesc{T}(mode, i, h) + return r +end + +const descs = WeakKeyDict() + +function desc(rnn) + d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = RNNDesc(rnn)) + copyparams!(rnn, d) + return d +end + +import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast + +mutable struct RNNCall{R} + rnn::R + reserve::CuVector{UInt8} + RNNCall{R}(rnn::R) where R = new(rnn) +end + +RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn) + +function (c::RNNCall)(args...) + rs, result = forwardTrain(desc(c.rnn), args...) + c.reserve = rs + return result +end + +istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) + +function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(RNNCall(m), x, h) : + forward(desc(m), x, h) + return result[2], result[1] +end + +function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(RNNCall(m), x, h) : + forward(desc(m), x, h) + return result[2], result[1] +end + +function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} + result = istrain(m, h, x) ? + track(RNNCall(m), x, h[1], h[2]) : + forward(desc(m), x, h[1], h[2]) + return (result[2], result[3]), result[1] +end + +function accum_transpose!(dst::CuArray, src::CuArray) + function kernel(dst, src) + I = @cuindex dst + dst[I...] += src[reverse(I)...] + return + end + blk, thr = cudims(dst) + @cuda (blk, thr) kernel(dst, src) + return dst +end + +function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h) + y, ho = y_ + dy, dho = Δ + h_ = hBatch(x, data(h)) + dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) + @back(x, dx) + @back(h, unbroadcast(h, dh)) + (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) + # We don't have to make this assumption, it's just slightly more complex. + @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) + istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) + istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) + istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) +end + +function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) + y, ho, co = y_ + dy, dho, dco = Δ + h_ = hBatch(x, data(h)) + c_ = hBatch(x, data(c)) + dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) + @back(x, dx) + @back(h, unbroadcast(h, dh)) + @back(c, unbroadcast(h, dc)) + (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) + @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) + istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi) + istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh) + istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db) +end diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index e4eb0c3d..41b66bda 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -1,7 +1,6 @@ -# TODO: broadcasting cat -combine(x::AbstractMatrix, h::AbstractVector) = vcat(x, h .* trues(1, size(x, 2))) -combine(x::AbstractVector, h::AbstractVector) = vcat(x, h) -combine(x::AbstractMatrix, h::AbstractMatrix) = vcat(x, h) +gate(h, n) = (1:h) + h*(n-1) +gate(x::AbstractVector, h, n) = x[gate(h,n)] +gate(x::AbstractMatrix, h, n) = x[gate(h,n),:] # Stateful recurrence @@ -74,16 +73,22 @@ flip(f, xs) = reverse(f.(reverse(xs))) # Vanilla RNN -struct RNNCell{D,V} - d::D +mutable struct RNNCell{F,A,V} + σ::F + Wi::A + Wh::A + b::V h::V end -RNNCell(in::Integer, out::Integer, σ = tanh; initW = glorot_uniform, initb = zeros) = - RNNCell(Dense(in+out, out, σ, initW = initW, initb = initb), param(initW(out))) +RNNCell(in::Integer, out::Integer, σ = tanh; + init = glorot_uniform) = + RNNCell(σ, param(init(out, in)), param(init(out, out)), + param(zeros(out)), param(initn(out))) function (m::RNNCell)(h, x) - h = m.d(combine(x, h)) + σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b + h = σ.(Wi*x .+ Wh*h .+ b) return h, h end @@ -91,8 +96,10 @@ hidden(m::RNNCell) = m.h treelike(RNNCell) -function Base.show(io::IO, m::RNNCell) - print(io, "RNNCell(", m.d, ")") +function Base.show(io::IO, l::RNNCell) + print(io, "RNNCell(", size(l.Wi, 2), ", ", size(l.Wi, 1)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") end """ @@ -105,40 +112,41 @@ RNN(a...; ka...) = Recur(RNNCell(a...; ka...)) # LSTM -struct LSTMCell{D1,D2,V} - forget::D1 - input::D1 - output::D1 - cell::D2 - h::V; c::V +mutable struct LSTMCell{A,V} + Wi::A + Wh::A + b::V + h::V + c::V end -function LSTMCell(in, out; initW = glorot_uniform, initb = zeros) - cell = LSTMCell([Dense(in+out, out, σ, initW = initW, initb = initb) for _ = 1:3]..., - Dense(in+out, out, tanh, initW = initW, initb = initb), - param(initW(out)), param(initW(out))) - cell.forget.b.data .= 1 +function LSTMCell(in::Integer, out::Integer; + init = glorot_uniform) + cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)), + param(initn(out)), param(initn(out))) + cell.b.data[gate(out, 2)] = 1 return cell end function (m::LSTMCell)(h_, x) - h, c = h_ - x′ = combine(x, h) - forget, input, output, cell = - m.forget(x′), m.input(x′), m.output(x′), m.cell(x′) + h, c = h_ # TODO: nicer syntax on 0.7 + b, o = m.b, size(h, 1) + g = m.Wi*x .+ m.Wh*h .+ b + input = σ.(gate(g, o, 1)) + forget = σ.(gate(g, o, 2)) + cell = tanh.(gate(g, o, 3)) + output = σ.(gate(g, o, 4)) c = forget .* c .+ input .* cell - h = output .* tanh.(c) - return (h, c), h + h′ = output .* tanh.(c) + return (h′, c), h′ end hidden(m::LSTMCell) = (m.h, m.c) treelike(LSTMCell) -Base.show(io::IO, m::LSTMCell) = - print(io, "LSTMCell(", - size(m.forget.W, 2) - size(m.forget.W, 1), ", ", - size(m.forget.W, 1), ')') +Base.show(io::IO, l::LSTMCell) = + print(io, "LSTMCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")") """ LSTM(in::Integer, out::Integer, σ = tanh) @@ -153,38 +161,33 @@ LSTM(a...; ka...) = Recur(LSTMCell(a...; ka...)) # GRU -struct GRUCell{D1,D2,V} - update::D1 - reset::D1 - candidate::D2 +mutable struct GRUCell{A,V} + Wi::A + Wh::A + b::V h::V end -function GRUCell(in, out) - cell = GRUCell(Dense(in+out, out, σ), - Dense(in+out, out, σ), - Dense(in+out, out, tanh), - param(initn(out))) - return cell -end +GRUCell(in, out; init = glorot_uniform) = + GRUCell(param(init(out*3, in)), param(init(out*3, out)), + param(zeros(out*3)), param(initn(out))) function (m::GRUCell)(h, x) - x′ = combine(x, h) - z = m.update(x′) - r = m.reset(x′) - h̃ = m.candidate(combine(r.*h, x)) - h = (1.-z).*h .+ z.*h̃ - return h, h + b, o = m.b, size(h, 1) + gx, gh = m.Wi*x, m.Wh*h + r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1)) + z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2)) + h̃ = tanh.(gate(gx, o, 3) .+ r .* gate(gh, o, 3) .+ gate(b, o, 3)) + h′ = (1.-z).*h̃ .+ z.*h + return h′, h′ end hidden(m::GRUCell) = m.h treelike(GRUCell) -Base.show(io::IO, m::GRUCell) = - print(io, "GRUCell(", - size(m.update.W, 2) - size(m.update.W, 1), ", ", - size(m.update.W, 1), ')') +Base.show(io::IO, l::GRUCell) = + print(io, "GRUCell(", size(l.Wi, 2), ", ", size(l.Wi, 1), ")") """ GRU(in::Integer, out::Integer, σ = tanh) diff --git a/test/cuarrays.jl b/test/cuda/cuda.jl similarity index 91% rename from test/cuarrays.jl rename to test/cuda/cuda.jl index e4fd3d58..91d3701b 100644 --- a/test/cuarrays.jl +++ b/test/cuda/cuda.jl @@ -21,3 +21,5 @@ cm = cu(m) @test cm(cu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}} end + +CuArrays.cudnn_available() && include("cudnn.jl") diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl new file mode 100644 index 00000000..85ab21c8 --- /dev/null +++ b/test/cuda/cudnn.jl @@ -0,0 +1,31 @@ +using Flux, CuArrays, Base.Test + +info("Testing Flux/CUDNN") + +@testset "RNN" begin + @testset for R in [RNN, GRU, LSTM] + x = param(rand(10,5)) + cux = cu(x) + rnn = R(10, 5) + curnn = mapleaves(cu, rnn) + y = (rnn(x); rnn(x)) + cuy = (curnn(cux); curnn(cux)) + + @test y.data ≈ collect(cuy.data) + @test haskey(Flux.CUDA.descs, curnn.cell) + + Δ = randn(size(y)) + + Flux.back!(y, Δ) + Flux.back!(cuy, cu(Δ)) + + @test x.grad ≈ collect(cux.grad) + @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) + @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) + @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) + @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) + if isdefined(rnn.cell, :c) + @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5bdb675d..47f7e9e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using Flux, Base.Test +srand(0) + @testset "Flux" begin include("utils.jl") @@ -10,7 +12,7 @@ include("optimise.jl") include("data.jl") if Base.find_in_path("CuArrays") ≠ nothing - include("cuarrays.jl") + include("cuda/cuda.jl") end end