test fillarray gradients
This commit is contained in:
parent
29eae312b8
commit
fe57215b7e
@ -297,11 +297,15 @@ unbroadcast(x::AbstractArray, Δ) =
|
|||||||
length(x) == length(Δ) ? trim(x, Δ) :
|
length(x) == length(Δ) ? trim(x, Δ) :
|
||||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
|
|
||||||
|
coerce_cuda(x::Union{CuArray,Nothing}) = x
|
||||||
|
|
||||||
|
coerce_cuda(x) = x .+ CuArrays.fill(0)
|
||||||
|
|
||||||
for RNN in (CuRNN, CuGRU)
|
for RNN in (CuRNN, CuGRU)
|
||||||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
||||||
(ho, y), function (Δ)
|
(ho, y), function (Δ)
|
||||||
dho, dy = Δ
|
dho, dy = coerce_cuda.(Δ)
|
||||||
h_ = hBatch(x, h)
|
h_ = hBatch(x, h)
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
@ -314,7 +318,7 @@ end
|
|||||||
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c)
|
||||||
((ho, co), y), function (Δ)
|
((ho, co), y), function (Δ)
|
||||||
dhc, dy = Δ
|
dhc, dy = coerce_cuda.(Δ)
|
||||||
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
dho, dco = dhc === nothing ? (nothing, nothing) : dhc
|
||||||
h_ = hBatch(x, h)
|
h_ = hBatch(x, h)
|
||||||
c_ = hBatch(x, c)
|
c_ = hBatch(x, c)
|
||||||
|
@ -1,6 +1,12 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
using Flux: forward
|
using Flux: forward
|
||||||
|
|
||||||
|
@testset for R in [RNN, GRU, LSTM]
|
||||||
|
m = R(10, 5) |> gpu
|
||||||
|
x = gpu(rand(10))
|
||||||
|
@test gradient(m -> sum(m(x)), m) isa Tuple
|
||||||
|
end
|
||||||
|
|
||||||
@testset "RNN" begin
|
@testset "RNN" begin
|
||||||
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
@testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
|
||||||
rnn = R(10, 5)
|
rnn = R(10, 5)
|
||||||
|
Loading…
Reference in New Issue
Block a user