hook up backward passes
parent a1d1930097
commit b8f148b012
@@ -105,7 +105,7 @@ getworkspace(bytes) =
     (workspace[] = CuVector{UInt8}(bytes))
 
 getworkspace(r::RNNDesc, seqlen, xdesc) =
-  getworkspace(rnnWorkspaceSize(rnn, seqlen, xdesc))
+  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
 
 function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   size = Csize_t[0]
@@ -145,6 +145,7 @@ end
 xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]
 
 hDesc(h::Void) = C_NULL, C_NULL
+hDesc(x::Integer) = (@assert x == 0; hDesc(nothing))
 function hDesc(h::CuArray)
   TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
 end
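The added hDesc(x::Integer) method lets a plain 0 stand in for an absent hidden or cell state, forwarding to the existing null-descriptor case. A hypothetical sketch of the dispatch (names and values are illustrative only, assuming the definitions above are loaded):

  hDesc(0)                      # asserts the value is 0, then behaves like hDesc(nothing): (C_NULL, C_NULL)
  h = cu(zeros(Float32, 5, 1))  # a hidden state on the GPU (cu from CuArrays, illustrative)
  hDesc(h)                      # (TensorDesc(Float32, (5, 1, 1)), h)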
@@ -272,19 +273,72 @@ function desc(rnn)
   return d
 end
 
+import Flux.Tracker: data, isleaf, istracked, track, back_, @back
+
+struct RNNCall{R}
+  rnn::R
+end
+
+(c::RNNCall)(args...) = forward(desc(c.rnn), args..., train = true)
+
 istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
 
 function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
-  return h, y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h) :
+    forward(desc(m), x, h)
+  return result[2], result[1]
 end
 
 function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
-  return h, y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h) :
+    forward(desc(m), x, h)
+  return result[2], result[1]
 end
 
 function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
-  y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = istrain(m, h, x))
-  return (h, c), y
+  result = istrain(m, h, x) ?
+    track(RNNCall(m), x, h[1], h[2]) :
+    forward(desc(m), x, h[1], h[2])
+  return (result[2], result[3]), result[1]
+end
+
+function accum_transpose!(dst::CuArray, src::CuArray)
+  function kernel(dst, src)
+    I = @cuindex dst
+    dst[I...] += src[reverse(I)...]
+    return
+  end
+  blk, thr = cudims(dst)
+  @cuda (blk, thr) kernel(dst, src)
+  return dst
+end
+
+function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
+  y, ho = y_
+  dy, dho = Δ
+  dx, dh = backwardData(descs[m.rnn], y, dy, dho, data(h))
+  @back(x, dx)
+  @back(h, dh)
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  # We don't have to make this assumption, it's just slightly more complex.
+  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
+  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
+  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
+  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
+end
+
+function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
+  y, ho, co = y_
+  dy, dho, dco = Δ
+  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, data(h), data(c))
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), data(h), y)
+  @back(x, dx)
+  @back(h, dh)
+  @back(c, dc)
+  @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
+  istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
+  istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
+  istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
+ end
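For orientation, here is a rough sketch of the round trip these definitions wire up (hypothetical, not part of the commit): assume m is a CuRNN whose Wi, Wh and b are tracked CuArray parameters, h is the hidden state and x an input column, so istrain(m, h, x) is true.

  # Forward: the call is recorded as track(RNNCall(m), x, h); RNNCall runs
  # forward(desc(m), x, h, train = true), so cuDNN keeps the reserve buffer needed later.
  h′, y = m(h, x)

  # Reverse: backpropagating a scalar loss reaches back_(::RNNCall{<:Union{CuRNN,CuGRU}}, ...)
  # above, which calls backwardData / backwardWeights and accumulates into Wi.grad, Wh.grad
  # and b.grad via accum_transpose!, which adds the transposed cuDNN gradients in place.
  Flux.Tracker.back!(sum(y))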