fix cuda code and tests
commit 487000ac31 (parent 62ec01a6f5)
@@ -268,48 +268,55 @@ end
 using ..Flux: @adjoint

 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h)
-  return result[2], result[1]
+  y, h′ = forward(desc(m), x, h)
+  return h′, y
 end

 function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h)
-  return result[2], result[1]
+  y, h′ = forward(desc(m), x, h)
+  return h′, y
 end

 function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  result = forward(desc(m), x, h[1], h[2])
-  return (result[2], result[3]), result[1]
+  y, h′, c′ = forward(desc(m), x, h[1], h[2])
+  return (h′, c′), y
 end

 (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))

 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

 unbroadcast(x::AbstractArray, Δ) =
   size(x) == size(Δ) ? Δ :
   length(x) == length(Δ) ? trim(x, Δ) :
     trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

 for RNN in (CuRNN, CuGRU)
-  @eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
-    reserve, result = forwardTrain(desc(m), x, h)
-    result, function (Δ)
-      y, ho = result
-      dy, dho = Δ
+  @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+    reserve, (y, ho) = forwardTrain(desc(m), x, h)
+    (ho, y), function (Δ)
+      dho, dy = Δ
       h_ = hBatch(x, h)
       dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
+      dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
+      (dm, unbroadcast(h, dh), dx)
     end
   end
 end

-@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
-  reserve, result = forwardTrain(desc(m), x, h, c)
-  result, function (Δ)
-    y, ho = result
-    dy, dho, dco = Δ
+@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+  reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
+  ((ho, co), y), function (Δ)
+    dhc, dy = Δ
+    dho, dco = dhc === nothing ? (nothing, nothing) : dhc
     h_ = hBatch(x, h)
     c_ = hBatch(x, c)
     dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-    (dx, unbroadcast(h, dh), unbroadcast(c, dc),
-     transpose(dWi), transpose(dWh), db)
+    dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
+    (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
   end
 end
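The `trim`/`unbroadcast` pair above is plain Julia, so its effect is easy to check on the CPU. A minimal sketch (the `h` and `Δ` values are invented for illustration): when an unbatched hidden state is broadcast against a batched gradient, the batch dimension is summed away and the result is reshaped back to the state's shape.

    # Same definitions as in the diff above.
    trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

    unbroadcast(x::AbstractArray, Δ) =
      size(x) == size(Δ) ? Δ :
      length(x) == length(Δ) ? trim(x, Δ) :
        trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

    h = zeros(5)      # unbatched hidden state, size (5,)
    Δ = ones(5, 3)    # gradient for a batch of 3, size (5, 3)
    @assert unbroadcast(h, Δ) == fill(3.0, 5)  # batch dim summed, shape restored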
@@ -1,5 +1,5 @@
 using Flux, CuArrays, Test
-trainmode(f, x...) = forward(f, x...)[1]
+using Flux: forward

 @testset "CUDNN BatchNorm" begin
   @testset "4D Input" begin
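The deleted `trainmode` helper ran a layer through Zygote's pullback API and kept only the primal: `forward(f, x...)` returns `(y, back)`, and `[1]` discards the pullback. The tests below now keep both parts so they can check gradients as well. A one-line sketch of that API (CPU, with a toy function `f` invented for illustration):

    using Flux: forward

    f(x) = 2x
    y, back = forward(f, 3.0)    # y == 6.0, back is the pullback
    @assert back(1.0) == (2.0,)  # gradient of f with respect to its argument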
@@ -8,16 +8,18 @@ trainmode(f, x...) = forward(f, x...)[1]
     cx = gpu(x)
     cm = gpu(m)

-    y = trainmode(m, x)
-    cy = trainmode(cm, cx)
+    y, back = forward((m, x) -> m(x), m, x)
+    cy, cback = forward((m, x) -> m(x), cm, cx)

     @test cpu(cy) ≈ y

-    g = gradient(()->sum(m(x)), params(m))
-    cg = gradient(()->sum(cm(cx)), params(cm))
+    Δ = randn(size(y))
+    dm, dx = back(Δ)
+    cdm, cdx = cback(gpu(Δ))

-    @test g[m.γ] ≈ cpu(cg[cm.γ])
-    @test g[m.β] ≈ cpu(cg[cm.β])
+    @test dm[].γ ≈ cpu(cdm[].γ)
+    @test dm[].β ≈ cpu(cdm[].β)
+    @test dx ≈ cpu(cdx)
   end

   @testset "2D Input" begin
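For reference, the new test pattern works identically on the CPU. A hedged sketch (input shape invented for illustration; the channel dimension must match `BatchNorm(3)`):

    using Flux
    using Flux: forward

    m = BatchNorm(3)
    x = rand(Float32, 2, 2, 3, 4)  # W × H × C × N

    y, back = forward((m, x) -> m(x), m, x)
    Δ = randn(Float32, size(y))    # seed for the pullback
    dm, dx = back(Δ)               # dm is a Ref: dm[].γ, dm[].β hold the parameter gradients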
@@ -26,17 +28,17 @@ trainmode(f, x...) = forward(f, x...)[1]
     cx = gpu(x)
     cm = gpu(m)

-    y = trainmode(m, x)
-    cy = trainmode(cm, cx)
-
-    @test cy isa CuArray{Float32,2}
+    y, back = forward((m, x) -> m(x), m, x)
+    cy, cback = forward((m, x) -> m(x), cm, cx)

     @test cpu(cy) ≈ y

-    g = gradient(()->sum(m(x)), params(m))
-    cg = gradient(()->sum(cm(cx)), params(cm))
+    Δ = randn(size(y))
+    dm, dx = back(Δ)
+    cdm, cdx = cback(gpu(Δ))

-    @test g[m.γ] ≈ cpu(cg[cm.γ])
-    @test g[m.β] ≈ cpu(cg[cm.β])
+    @test dm[].γ ≈ cpu(cdm[].γ)
+    @test dm[].β ≈ cpu(cdm[].β)
+    @test dx ≈ cpu(cdx)
   end
 end
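Note that the deleted assertions compared implicit-parameter gradients of `sum(m(x))`, while the new ones compare pullback outputs. Seeding the pullback with all ones recovers the same numbers, since the gradient of `sum(y)` with respect to `y` is all ones. A quick CPU sketch of that equivalence (input shape invented; not part of the test suite):

    using Flux
    using Flux: forward

    m = BatchNorm(3)
    x = rand(Float32, 3, 4)

    g = gradient(() -> sum(m(x)), params(m))
    y, back = forward((m, x) -> m(x), m, x)
    dm, _ = back(ones(Float32, size(y)))

    @assert g[m.γ] ≈ dm[].γ
    @assert g[m.β] ≈ dm[].β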
@@ -1,46 +1,49 @@
 using Flux, CuArrays, Test
+using Flux: forward

 @testset "RNN" begin
-  @testset for R in [RNN, GRU, LSTM]
+  @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
     rnn = R(10, 5)
     curnn = mapleaves(gpu, rnn)
-    @testset for batch_size in (1, 5)
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      x = batch_size == 1 ?
-          rand(10) :
-          rand(10, batch_size)
-      cux = gpu(x)
-      y = (rnn(x); rnn(x))
-      cuy = (curnn(cux); curnn(cux))
-
-      @test y ≈ collect(cuy)
-      @test haskey(Flux.CUDA.descs, curnn.cell)
-
-      #Δ = randn(size(y))
-
-      #Flux.back!(y, Δ)
-      #Flux.back!(cuy, gpu(Δ))
-
-      @test x ≈ collect(cux)
-      @test rnn.cell.Wi ≈ collect(curnn.cell.Wi)
-      @test rnn.cell.Wh ≈ collect(curnn.cell.Wh)
-      @test rnn.cell.b ≈ collect(curnn.cell.b)
-      @test rnn.cell.h ≈ collect(curnn.cell.h)
-      if isdefined(rnn.cell, :c)
-        @test rnn.cell.c ≈ collect(curnn.cell.c)
-      end
-
-      Flux.reset!(rnn)
-      Flux.reset!(curnn)
-      ohx = batch_size == 1 ?
-          Flux.onehot(rand(1:10), 1:10) :
-          Flux.onehotbatch(rand(1:10, batch_size), 1:10)
-      cuohx = gpu(ohx)
-      y = (rnn(ohx); rnn(ohx))
-      cuy = (curnn(cuohx); curnn(cuohx))
-
-      @test y ≈ collect(cuy)
-    end
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    x = batch_size == 1 ?
+        rand(10) :
+        rand(10, batch_size)
+    cux = gpu(x)
+
+    y, back = forward((r, x) -> (r(x)), rnn, x)
+    cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
+
+    @test y ≈ collect(cuy)
+    @test haskey(Flux.CUDA.descs, curnn.cell)
+
+    ȳ = randn(size(y))
+    m̄, x̄ = back(ȳ)
+    cum̄, cux̄ = cuback(gpu(ȳ))
+
+    @test x̄ ≈ collect(cux̄)
+    @test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
+    @test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
+    @test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
+    if m̄[].state isa Tuple
+      for (x, cx) in zip(m̄[].state, cum̄[].state)
+        @test x ≈ collect(cx)
+      end
+    else
+      @test m̄[].state ≈ collect(cum̄[].state)
+    end
+
+    Flux.reset!(rnn)
+    Flux.reset!(curnn)
+    ohx = batch_size == 1 ?
+        Flux.onehot(rand(1:10), 1:10) :
+        Flux.onehotbatch(rand(1:10, batch_size), 1:10)
+    cuohx = gpu(ohx)
+    y = (rnn(ohx); rnn(ohx))
+    cuy = (curnn(cuohx); curnn(cuohx))
+
+    @test y ≈ collect(cuy)
   end
 end
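A note on the gradient plumbing exercised above: `back` returns `(m̄, x̄)`, where `m̄` is a `Ref` mirroring the `Recur` wrapper, so parameter gradients live under `m̄[].cell[]` and the state gradient under `m̄[].state` (a tuple for LSTM, matching the branch in the test). A CPU-only sketch of that access pattern (no GPU required; sizes follow the test):

    using Flux
    using Flux: forward

    rnn = RNN(10, 5)
    x = rand(10)

    y, back = forward((r, x) -> r(x), rnn, x)
    m̄, x̄ = back(randn(size(y)))

    m̄[].cell[].Wi  # input-weight gradient, compared against the CuArrays path above
    m̄[].state      # hidden-state gradient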