cudnn rnns + implicit gradients

parent fe57215b7e
commit b348b20452
@@ -269,7 +269,8 @@ function desc(rnn)
   return d
 end

-using ..Flux: @adjoint
+import Zygote
+using Zygote: @adjoint

 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
   y, h′ = forward(desc(m), x, h)
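A side note on the import change above: `@adjoint` now comes from Zygote directly instead of via `..Flux`, and `import Zygote` makes the qualified `Zygote.*` calls used further down resolve. A trivial standalone check that custom adjoints work when the macro is pulled in this way (toy `double` function, not from this commit):

using Zygote
using Zygote: @adjoint

double(x) = 2x
@adjoint double(x) = double(x), ȳ -> (2ȳ,)

gradient(double, 3.0)  # (2.0,)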
@@ -301,6 +302,15 @@ coerce_cuda(x::Union{CuArray,Nothing}) = x

 coerce_cuda(x) = x .+ CuArrays.fill(0)

+function struct_grad!(cx::Zygote.Context, x, x̄)
+  for f in fieldnames(typeof(x))
+    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
+  end
+  dx = Zygote.grad_mut(cx, x)
+  dx[] = Zygote.accum(dx[], x̄)
+  return dx
+end
+
 for RNN in (CuRNN, CuGRU)
   @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
     reserve, (y, ho) = forwardTrain(desc(m), x, h)
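`struct_grad!`, added above, is the core of this commit. Given a struct `x` and a structural gradient `x̄` (a named tuple of per-field gradients), it does two things: `Zygote.accum_param` registers each field's gradient with the context, which is what makes it visible to implicit, `params`-style differentiation, and `Zygote.grad_mut`/`Zygote.accum` fold `x̄` into the `Ref` the context keeps for mutable structs, which is what explicit differentiation returns. A plain-CPU sketch of the two modes it has to serve (toy immutable `Affine`, not from the commit; here Zygote's own `getfield` adjoint calls `accum_param`, a step that a hand-written adjoint bypasses and so must do itself):

using Zygote
using Zygote: Params

struct Affine
  W::Matrix{Float64}
  b::Vector{Float64}
end
(a::Affine)(x) = a.W * x .+ a.b

a = Affine(randn(3, 2), zeros(3))
x = randn(2)

# Explicit mode: the structural gradient is the return value.
ā, = gradient(a -> sum(a(x)), a)    # ā.W, ā.b

# Implicit mode: gradients are looked up by parameter array. This lookup is
# populated by accum_param, the call struct_grad! makes for every field.
gs = gradient(() -> sum(a(x)), Params([a.W, a.b]))
gs[a.W] == ā.W                      # true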
@@ -309,7 +319,7 @@ for RNN in (CuRNN, CuGRU)
       h_ = hBatch(x, h)
       dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
+      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
       (dm, unbroadcast(h, dh), dx)
     end
   end
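The one-line change in this hunk (and its LSTM twin below) is what the commit title means by "implicit gradients": the old `dm = Ref{Any}((...))` produced a usable explicit gradient but never called `accum_param`, so the cuDNN weight gradients were invisible to `params`-based training. A sketch of the same pattern on a toy mutable layer (hypothetical `Toy`; `struct_grad!` copied verbatim from the hunk above):

using Zygote
using Zygote: @adjoint

mutable struct Toy
  W::Matrix{Float64}
  b::Vector{Float64}
end
(t::Toy)(x) = t.W * x .+ t.b

function struct_grad!(cx::Zygote.Context, x, x̄)
  for f in fieldnames(typeof(x))
    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
  end
  dx = Zygote.grad_mut(cx, x)
  dx[] = Zygote.accum(dx[], x̄)
  return dx
end

@adjoint function (t::Toy)(x::Vector{Float64})
  y = t(x)
  return y, function (ȳ)
    # Register the field gradients for both modes, as the CuRNN adjoints do.
    t̄ = struct_grad!(__context__, t, (W = ȳ * x', b = ȳ))
    return (t̄, t.W' * ȳ)
  end
end

With this adjoint in place, `gradient(() -> sum(t(x)), Params([t.W, t.b]))[t.W]` and the explicit `gradient(t -> sum(t(x)), t)[1][].W` should agree, which is exactly what the updated test at the bottom checks for the real RNNs.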
@@ -324,7 +334,7 @@ end
       c_ = hBatch(x, c)
       dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
+      dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
       (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
     end
   end
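In the LSTM adjoint just above, the state gradient comes back as the tuple `(unbroadcast(h, dh), unbroadcast(c, dc))`, mirroring the `(h, c)` state argument, and each piece goes through `unbroadcast` so that a state broadcast across the batch in the forward pass is summed back to its original shape. A rough standalone equivalent of that reduction (hypothetical `unbroadcast_like`; whether `unbroadcast` here is Zygote's internal or a local helper, this is the behaviour assumed):

# Sum x̄ over whichever dimensions broadcasting expanded, then reshape back
# to size(x): e.g. a length-5 h broadcast against a 5×3 batch gets a 5×3 dh,
# which must be reduced back to length 5.
function unbroadcast_like(x::AbstractArray, x̄::AbstractArray)
  size(x) == size(x̄) && return x̄
  dims = findall(d -> size(x, d) != size(x̄, d), 1:ndims(x̄))
  return reshape(sum(x̄; dims = dims), size(x))
end

unbroadcast_like(zeros(5), ones(5, 3))  # == fill(3.0, 5)

The final hunk below is from the test suite: its numbering restarts at line 4, in a file that opens with `using Flux: forward`.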
@@ -4,7 +4,10 @@ using Flux: forward
 @testset for R in [RNN, GRU, LSTM]
   m = R(10, 5) |> gpu
   x = gpu(rand(10))
-  @test gradient(m -> sum(m(x)), m) isa Tuple
+  (m̄,) = gradient(m -> sum(m(x)), m)
+  Flux.reset!(m)
+  θ = gradient(() -> sum(m(x)), params(m))
+  @test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
 end

 @testset "RNN" begin
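The strengthened test exercises both modes against each other: the explicit gradient `m̄` is a `Ref`-wrapped structural gradient (hence the `[]` derefs, a consequence of `grad_mut`), and it must match the implicit `params`-based gradient on `Wi`. A rough CPU analogue, assuming a Flux version contemporary with this commit (the `RNN(10, 5)` constructor and the exact `[]` nesting are version-dependent):

using Flux, Zygote

m = RNN(10, 5)                  # CPU model; the real test pipes this |> gpu
x = rand(Float32, 10)

m̄, = gradient(m -> sum(m(x)), m)
Flux.reset!(m)                  # clear the hidden state between the two passes
θ = gradient(() -> sum(m(x)), params(m))

m̄[].cell[].Wi == θ[m.cell.Wi]   # the two modes should report the same gradient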