cudnn rnns + implicit gradients
This commit is contained in:
parent
fe57215b7e
commit
b348b20452
|
@ -269,7 +269,8 @@ function desc(rnn)
|
|||
return d
|
||||
end
|
||||
|
||||
using ..Flux: @adjoint
|
||||
import Zygote
|
||||
using Zygote: @adjoint
|
||||
|
||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
y, h′ = forward(desc(m), x, h)
|
||||
|
@ -301,6 +302,15 @@ coerce_cuda(x::Union{CuArray,Nothing}) = x
|
|||
|
||||
# Generic fallback of `coerce_cuda`: broadcast-add a zero-dimensional
# `CuArrays.fill(0)` onto `x` and return the result.
# NOTE(review): presumably this is a way to get a GPU array out of a non-GPU
# `x` without changing its values — confirm against the callers.
function coerce_cuda(x)
    return x .+ CuArrays.fill(0)
end
|
||||
|
||||
# Register the gradient `x̄` of the struct `x` with the Zygote context `cx`.
#
# For every field of `x`, the matching entry of `x̄` is handed to
# `Zygote.accum_param` together with the field's value, so `x̄` must expose the
# same field names as `x`. The whole `x̄` is then accumulated into the gradient
# reference that `Zygote.grad_mut` provides for `x`, and that reference is
# returned.
function struct_grad!(cx::Zygote.Context, x, x̄)
    # Per-field accumulation against the context.
    foreach(fieldnames(typeof(x))) do name
        Zygote.accum_param(cx, getfield(x, name), getfield(x̄, name))
    end
    # Structural accumulation into the reference kept for `x` itself.
    slot = Zygote.grad_mut(cx, x)
    slot[] = Zygote.accum(slot[], x̄)
    return slot
end
|
||||
|
||||
for RNN in (CuRNN, CuGRU)
|
||||
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
reserve, (y, ho) = forwardTrain(desc(m), x, h)
|
||||
|
@ -309,7 +319,7 @@ for RNN in (CuRNN, CuGRU)
|
|||
h_ = hBatch(x, h)
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
||||
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
|
||||
(dm, unbroadcast(h, dh), dx)
|
||||
end
|
||||
end
|
||||
|
@ -324,7 +334,7 @@ end
|
|||
c_ = hBatch(x, c)
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
||||
dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
|
||||
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -4,7 +4,10 @@ using Flux: forward
|
|||
# For each recurrent layer type, compare the gradient obtained by passing the
# model as an explicit argument with the gradient obtained through `params(m)`.
@testset for R in [RNN, GRU, LSTM]
|
||||
m = R(10, 5) |> gpu
|
||||
x = gpu(rand(10))
|
||||
# NOTE(review): the hunk header (-4,7 +4,10) suggests this assertion is the
# pre-change line that the four statements below replace — confirm before
# treating both as live code.
@test gradient(m -> sum(m(x)), m) isa Tuple
|
||||
# Gradient with the model passed as an explicit argument.
(m̄,) = gradient(m -> sum(m(x)), m)
|
||||
# Reset the model before the second gradient call (presumably so both passes
# start from the same recurrent state — confirm).
Flux.reset!(m)
|
||||
# Gradient taken through `params(m)` with a zero-argument closure.
θ = gradient(() -> sum(m(x)), params(m))
|
||||
# Both routes should report the same gradient for the cell's `Wi`; `collect`
# is applied to both sides before comparing.
@test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
|
||||
end
|
||||
|
||||
@testset "RNN" begin
|
||||
|
|
Loading…
Reference in New Issue