fix cuda code and tests

Mike Innes 2019-08-19 16:56:48 +01:00
parent 62ec01a6f5
commit 487000ac31
3 changed files with 84 additions and 67 deletions

src/cuda/curnn.jl
@@ -268,48 +268,55 @@ end
 using ..Flux: @adjoint
 
 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h)
-  return result[2], result[1]
+  y, h = forward(desc(m), x, h)
+  return h, y
 end
 
 function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h)
-  return result[2], result[1]
+  y, h = forward(desc(m), x, h)
+  return h, y
 end
 
 function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h[1], h[2])
-  return (result[2], result[3]), result[1]
+  y, h, c = forward(desc(m), x, h[1], h[2])
+  return (h, c), y
 end
 
 (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
 
 unbroadcast(x::AbstractArray, Δ) =
   size(x) == size(Δ) ? Δ :
   length(x) == length(Δ) ? trim(x, Δ) :
     trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
 
 for RNN in (CuRNN, CuGRU)
-  @eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
-    reserve, result = forwardTrain(desc(m), x, h)
-    result, function (Δ)
-      y, ho = result
-      dy, dho = Δ
+  @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+    reserve, (y, ho) = forwardTrain(desc(m), x, h)
+    (ho, y), function (Δ)
+      dho, dy = Δ
       h_ = hBatch(x, h)
       dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
+      dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
+      (dm, unbroadcast(h, dh), dx)
     end
   end
 end
 
-@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), x, h, c)
-  result, function (Δ)
-    y, ho = result
-    dy, dho, dco = Δ
+@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
+  ((ho, co), y), function (Δ)
+    dhc, dy = Δ
+    dho, dco = dhc === nothing ? (nothing, nothing) : dhc
     h_ = hBatch(x, h)
     c_ = hBatch(x, c)
     dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-    (dx, unbroadcast(h, dh), unbroadcast(c, dc),
-     transpose(dWi), transpose(dWh), db)
+    dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
+    (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
  end
 end
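
Note on this change: the adjoints are now defined on the stateful call m(h, x) itself rather than on an explicit argument list (x, h, Wi, Wh, b), and the weight gradients come back as a Ref{Any} wrapping a named tuple that mirrors the cell's fields, which is Zygote's representation for gradients of mutable structs. A minimal usage sketch, not part of the commit, mirroring the updated tests below (assumes a working CuArrays setup):

using Flux, CuArrays
using Flux: forward        # Zygote's forward/pullback, as used in the tests

rnn   = RNN(10, 5)
curnn = mapleaves(gpu, rnn)  # move the layer to the GPU
x     = gpu(rand(10))

# `forward` returns the output plus a pullback; the layer gradient is
# Ref-like, so the fields produced by the adjoint above sit behind [].
y, back = forward((r, x) -> r(x), curnn, x)
m̄, x̄ = back(gpu(randn(size(y))))
m̄[].cell[].Wi              # gradient w.r.t. the CuDNN input weight matrix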

test/cuda/cudnn.jl
@@ -1,5 +1,5 @@
 using Flux, CuArrays, Test
-trainmode(f, x...) = forward(f, x...)[1]
+using Flux: forward
 
 @testset "CUDNN BatchNorm" begin
   @testset "4D Input" begin
@@ -8,16 +8,18 @@ trainmode(f, x...) = forward(f, x...)[1]
     cx = gpu(x)
     cm = gpu(m)
 
-    y = trainmode(m, x)
-    cy = trainmode(cm, cx)
+    y, back = forward((m, x) -> m(x), m, x)
+    cy, cback = forward((m, x) -> m(x), cm, cx)
 
     @test cpu(cy) ≈ y
 
-    g = gradient(()->sum(m(x)), params(m))
-    cg = gradient(()->sum(cm(cx)), params(cm))
+    Δ = randn(size(y))
+    dm, dx = back(Δ)
+    cdm, cdx = cback(gpu(Δ))
 
-    @test g[m.γ] ≈ cpu(cg[cm.γ])
-    @test g[m.β] ≈ cpu(cg[cm.β])
+    @test dm[].γ ≈ cpu(cdm[].γ)
+    @test dm[].β ≈ cpu(cdm[].β)
+    @test dx ≈ cpu(cdx)
   end
 
   @testset "2D Input" begin
@@ -26,17 +28,17 @@ trainmode(f, x...) = forward(f, x...)[1]
     cx = gpu(x)
     cm = gpu(m)
 
-    y = trainmode(m, x)
-    cy = trainmode(cm, cx)
-
-    @test cy isa CuArray{Float32,2}
+    y, back = forward((m, x) -> m(x), m, x)
+    cy, cback = forward((m, x) -> m(x), cm, cx)
 
     @test cpu(cy) ≈ y
 
-    g = gradient(()->sum(m(x)), params(m))
-    cg = gradient(()->sum(cm(cx)), params(cm))
+    Δ = randn(size(y))
+    dm, dx = back(Δ)
+    cdm, cdx = cback(gpu(Δ))
 
-    @test g[m.γ] ≈ cpu(cg[cm.γ])
-    @test g[m.β] ≈ cpu(cg[cm.β])
+    @test dm[].γ ≈ cpu(cdm[].γ)
+    @test dm[].β ≈ cpu(cdm[].β)
+    @test dx ≈ cpu(cdx)
   end
 end
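
The BatchNorm tests now differentiate through the layer with forward and compare the CPU pullback against the CuDNN one field by field (dm[].γ, dm[].β), instead of going through the removed trainmode helper. The new structural gradients should agree with the old implicit-parameter style; a CPU-only sanity sketch, an illustration rather than part of the commit:

using Flux, Test
using Flux: forward

m = BatchNorm(3)
x = rand(Float64, 3, 4)

g = gradient(() -> sum(m(x)), params(m))  # old style: Grads keyed by parameter
y, back = forward((m, x) -> m(x), m, x)   # new style: structural gradients
dm, dx = back(ones(size(y)))              # a cotangent of ones gives d(sum)/dθ
@test g[m.γ] ≈ dm[].γ
@test g[m.β] ≈ dm[].β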

test/cuda/curnn.jl
@@ -1,46 +1,54 @@
 using Flux, CuArrays, Test
+using Flux: forward
 
 @testset "RNN" begin
-  @testset for R in [RNN, GRU, LSTM]
+  @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
     rnn = R(10, 5)
     curnn = mapleaves(gpu, rnn)
-    @testset for batch_size in (1, 5)
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      x = batch_size == 1 ?
-          rand(10) :
-          rand(10, batch_size)
-      cux = gpu(x)
-
-      y = (rnn(x); rnn(x))
-      cuy = (curnn(cux); curnn(cux))
-
-      @test y ≈ collect(cuy)
-      @test haskey(Flux.CUDA.descs, curnn.cell)
-
-      #Δ = randn(size(y))
-      #Flux.back!(y, Δ)
-      #Flux.back!(cuy, gpu(Δ))
-
-      @test x ≈ collect(cux)
-      @test rnn.cell.Wi ≈ collect(curnn.cell.Wi)
-      @test rnn.cell.Wh ≈ collect(curnn.cell.Wh)
-      @test rnn.cell.b ≈ collect(curnn.cell.b)
-      @test rnn.cell.h ≈ collect(curnn.cell.h)
-      if isdefined(rnn.cell, :c)
-        @test rnn.cell.c ≈ collect(curnn.cell.c)
-      end
-
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      ohx = batch_size == 1 ?
-          Flux.onehot(rand(1:10), 1:10) :
-          Flux.onehotbatch(rand(1:10, batch_size), 1:10)
-      cuohx = gpu(ohx)
-      y = (rnn(ohx); rnn(ohx))
-      cuy = (curnn(cuohx); curnn(cuohx))
-
-      @test y ≈ collect(cuy)
-    end
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    x = batch_size == 1 ?
+        rand(10) :
+        rand(10, batch_size)
+    cux = gpu(x)
+
+    y, back = forward((r, x) -> (r(x)), rnn, x)
+    cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
+
+    @test y ≈ collect(cuy)
+    @test haskey(Flux.CUDA.descs, curnn.cell)
+
+    ȳ = randn(size(y))
+    m̄, x̄ = back(ȳ)
+    cum̄, cux̄ = cuback(gpu(ȳ))
+
+    m̄[].cell[].Wi
+
+    m̄[].state
+    cum̄[].state
+
+    @test x̄ ≈ collect(cux̄)
+    @test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
+    @test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
+    @test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
+    if m̄[].state isa Tuple
+      for (x, cx) in zip(m̄[].state, cum̄[].state)
+        @test x ≈ collect(cx)
+      end
+    else
+      @test m̄[].state ≈ collect(cum̄[].state)
+    end
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    ohx = batch_size == 1 ?
+        Flux.onehot(rand(1:10), 1:10) :
+        Flux.onehotbatch(rand(1:10, batch_size), 1:10)
+    cuohx = gpu(ohx)
+    y = (rnn(ohx); rnn(ohx))
+    cuy = (curnn(cuohx); curnn(cuohx))
+
+    @test y ≈ collect(cuy)
   end
 end
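
In the rewritten RNN tests, the pullback gradient m̄ mirrors the Recur wrapper: m̄[].cell[] holds the weight gradients and m̄[].state the gradient of the initial state, which for LSTM is an (h, c) tuple; that is what the isa Tuple branch above distinguishes from the plain-array state of RNN and GRU. A CPU-only sketch of that shape, assuming the CPU path mirrors the CuDNN one tested here:

using Flux
using Flux: forward

lstm = LSTM(10, 5)
x = rand(10)

y, back = forward((r, x) -> r(x), lstm, x)
m̄, x̄ = back(ones(size(y)))
m̄[].state isa Tuple  # true for LSTM: gradients for both h and c
length(m̄[].state)    # 2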