cudnn rnns + implicit gradients

This commit is contained in:
Mike Innes 2019-09-17 15:41:42 +01:00
parent fe57215b7e
commit b348b20452
2 changed files with 17 additions and 4 deletions

View File

@ -269,7 +269,8 @@ function desc(rnn)
return d
end
using ..Flux: @adjoint
import Zygote
using Zygote: @adjoint
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), x, h)
@ -301,6 +302,15 @@ coerce_cuda(x::Union{CuArray,Nothing}) = x
coerce_cuda(x) = x .+ CuArrays.fill(0)
function struct_grad!(cx::Zygote.Context, x, )
for f in fieldnames(typeof(x))
Zygote.accum_param(cx, getfield(x, f), getfield(, f))
end
dx = Zygote.grad_mut(cx, x)
dx[] = Zygote.accum(dx[], )
return dx
end
for RNN in (CuRNN, CuGRU)
@eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
reserve, (y, ho) = forwardTrain(desc(m), x, h)
@ -309,7 +319,7 @@ for RNN in (CuRNN, CuGRU)
h_ = hBatch(x, h)
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
(dm, unbroadcast(h, dh), dx)
end
end
@ -324,7 +334,7 @@ end
c_ = hBatch(x, c)
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
(dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
end
end

View File

@ -4,7 +4,10 @@ using Flux: forward
@testset for R in [RNN, GRU, LSTM]
m = R(10, 5) |> gpu
x = gpu(rand(10))
@test gradient(m -> sum(m(x)), m) isa Tuple
(,) = gradient(m -> sum(m(x)), m)
Flux.reset!(m)
θ = gradient(() -> sum(m(x)), params(m))
@test collect([].cell[].Wi) == collect(θ[m.cell.Wi])
end
@testset "RNN" begin