fix cuda code and tests
This commit is contained in:
parent
62ec01a6f5
commit
487000ac31
@ -268,48 +268,55 @@ end
|
|||||||
using ..Flux: @adjoint
|
using ..Flux: @adjoint
|
||||||
|
|
||||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = forward(desc(m), x, h)
|
y, h′ = forward(desc(m), x, h)
|
||||||
return result[2], result[1]
|
return h′, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = forward(desc(m), x, h)
|
y, h′ = forward(desc(m), x, h)
|
||||||
return result[2], result[1]
|
return h′, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = forward(desc(m), x, h[1], h[2])
|
y, h′, c′ = forward(desc(m), x, h[1], h[2])
|
||||||
return (result[2], result[3]), result[1]
|
return (h′, c′), y
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||||
|
|
||||||
|
unbroadcast(x::AbstractArray, Δ) =
|
||||||
|
size(x) == size(Δ) ? Δ :
|
||||||
|
length(x) == length(Δ) ? trim(x, Δ) :
|
||||||
|
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
|
|
||||||
for RNN in (CuRNN, CuGRU)
|
for RNN in (CuRNN, CuGRU)
|
||||||
@eval @adjoint function (m::$RNN)(x, h, Wi, Wh, b)
|
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, result = forwardTrain(desc(m), x, h)
|
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
||||||
result, function (Δ)
|
(ho, y), function (Δ)
|
||||||
y, ho = result
|
dho, dy = Δ
|
||||||
dy, dho = Δ
|
|
||||||
h_ = hBatch(x, h)
|
h_ = hBatch(x, h)
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
(dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
|
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
||||||
|
(dm, unbroadcast(h, dh), dx)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, result = forwardTrain(desc(m), x, h, c)
|
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
||||||
result, function (Δ)
|
((ho, co), y), function (Δ)
|
||||||
y, ho = result
|
dhc, dy = Δ
|
||||||
dy, dho, dco = Δ
|
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||||
h_ = hBatch(x, h)
|
h_ = hBatch(x, h)
|
||||||
c_ = hBatch(x, c)
|
c_ = hBatch(x, c)
|
||||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
||||||
transpose(dWi), transpose(dWh), db)
|
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
trainmode(f, x...) = forward(f, x...)[1]
|
using Flux: forward
|
||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
@testset "4D Input" begin
|
@testset "4D Input" begin
|
||||||
@ -8,16 +8,18 @@ trainmode(f, x...) = forward(f, x...)[1]
|
|||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
y = trainmode(m, x)
|
y, back = forward((m, x) -> m(x), m, x)
|
||||||
cy = trainmode(cm, cx)
|
cy, cback = forward((m, x) -> m(x), cm, cx)
|
||||||
|
|
||||||
@test cpu(cy) ≈ y
|
@test cpu(cy) ≈ y
|
||||||
|
|
||||||
g = gradient(()->sum(m(x)), params(m))
|
Δ = randn(size(y))
|
||||||
cg = gradient(()->sum(cm(cx)), params(cm))
|
dm, dx = back(Δ)
|
||||||
|
cdm, cdx = cback(gpu(Δ))
|
||||||
|
|
||||||
@test g[m.γ] ≈ cpu(cg[cm.γ])
|
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||||
@test g[m.β] ≈ cpu(cg[cm.β])
|
@test dm[].β ≈ cpu(cdm[].β)
|
||||||
|
@test dx ≈ cpu(cdx)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "2D Input" begin
|
@testset "2D Input" begin
|
||||||
@ -26,17 +28,17 @@ trainmode(f, x...) = forward(f, x...)[1]
|
|||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
|
||||||
y = trainmode(m, x)
|
y, back = forward((m, x) -> m(x), m, x)
|
||||||
cy = trainmode(cm, cx)
|
cy, cback = forward((m, x) -> m(x), cm, cx)
|
||||||
|
|
||||||
@test cy isa CuArray{Float32,2}
|
|
||||||
|
|
||||||
@test cpu(cy) ≈ y
|
@test cpu(cy) ≈ y
|
||||||
|
|
||||||
g = gradient(()->sum(m(x)), params(m))
|
Δ = randn(size(y))
|
||||||
cg = gradient(()->sum(cm(cx)), params(cm))
|
dm, dx = back(Δ)
|
||||||
|
cdm, cdx = cback(gpu(Δ))
|
||||||
|
|
||||||
@test g[m.γ] ≈ cpu(cg[cm.γ])
|
@test dm[].γ ≈ cpu(cdm[].γ)
|
||||||
@test g[m.β] ≈ cpu(cg[cm.β])
|
@test dm[].β ≈ cpu(cdm[].β)
|
||||||
|
@test dx ≈ cpu(cdx)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,46 +1,54 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
|
using Flux: forward
|
||||||
|
|
||||||
@testset "RNN" begin
|
@testset "RNN" begin
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
||||||
rnn = R(10, 5)
|
rnn = R(10, 5)
|
||||||
curnn = mapleaves(gpu, rnn)
|
curnn = mapleaves(gpu, rnn)
|
||||||
@testset for batch_size in (1, 5)
|
|
||||||
Flux.reset!(rnn)
|
|
||||||
Flux.reset!(curnn)
|
|
||||||
x = batch_size == 1 ?
|
|
||||||
rand(10) :
|
|
||||||
rand(10, batch_size)
|
|
||||||
cux = gpu(x)
|
|
||||||
y = (rnn(x); rnn(x))
|
|
||||||
cuy = (curnn(cux); curnn(cux))
|
|
||||||
|
|
||||||
@test y ≈ collect(cuy)
|
Flux.reset!(rnn)
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
Flux.reset!(curnn)
|
||||||
|
x = batch_size == 1 ?
|
||||||
|
rand(10) :
|
||||||
|
rand(10, batch_size)
|
||||||
|
cux = gpu(x)
|
||||||
|
|
||||||
#Δ = randn(size(y))
|
y, back = forward((r, x) -> (r(x)), rnn, x)
|
||||||
|
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
|
||||||
|
|
||||||
#Flux.back!(y, Δ)
|
@test y ≈ collect(cuy)
|
||||||
#Flux.back!(cuy, gpu(Δ))
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
|
||||||
@test x ≈ collect(cux)
|
ȳ = randn(size(y))
|
||||||
@test rnn.cell.Wi ≈ collect(curnn.cell.Wi)
|
m̄, x̄ = back(ȳ)
|
||||||
@test rnn.cell.Wh ≈ collect(curnn.cell.Wh)
|
cum̄, cux̄ = cuback(gpu(ȳ))
|
||||||
@test rnn.cell.b ≈ collect(curnn.cell.b)
|
|
||||||
@test rnn.cell.h ≈ collect(curnn.cell.h)
|
m̄[].cell[].Wi
|
||||||
if isdefined(rnn.cell, :c)
|
|
||||||
@test rnn.cell.c ≈ collect(curnn.cell.c)
|
m̄[].state
|
||||||
|
cum̄[].state
|
||||||
|
|
||||||
|
@test x̄ ≈ collect(cux̄)
|
||||||
|
@test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
|
||||||
|
@test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
|
||||||
|
@test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)
|
||||||
|
if m̄[].state isa Tuple
|
||||||
|
for (x, cx) in zip(m̄[].state, cum̄[].state)
|
||||||
|
@test x ≈ collect(cx)
|
||||||
end
|
end
|
||||||
|
else
|
||||||
Flux.reset!(rnn)
|
@test m̄[].state ≈ collect(cum̄[].state)
|
||||||
Flux.reset!(curnn)
|
|
||||||
ohx = batch_size == 1 ?
|
|
||||||
Flux.onehot(rand(1:10), 1:10) :
|
|
||||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
|
||||||
cuohx = gpu(ohx)
|
|
||||||
y = (rnn(ohx); rnn(ohx))
|
|
||||||
cuy = (curnn(cuohx); curnn(cuohx))
|
|
||||||
|
|
||||||
@test y ≈ collect(cuy)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
ohx = batch_size == 1 ?
|
||||||
|
Flux.onehot(rand(1:10), 1:10) :
|
||||||
|
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||||
|
cuohx = gpu(ohx)
|
||||||
|
y = (rnn(ohx); rnn(ohx))
|
||||||
|
cuy = (curnn(cuohx); curnn(cuohx))
|
||||||
|
|
||||||
|
@test y ≈ collect(cuy)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user