update cudnn rnn

Mike Innes 2018-07-10 18:16:37 +01:00
parent 70b5efeb4e
commit 10a169bb77
4 changed files with 36 additions and 61 deletions


@@ -286,41 +286,28 @@ function desc(rnn)
   return d
 end
 
-import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
-
-mutable struct RNNCall{R}
-  rnn::R
-  reserve::CuVector{UInt8}
-  RNNCall{R}(rnn::R) where R = new(rnn)
-end
-
-RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
-
-function (c::RNNCall)(args...)
-  rs, result = forwardTrain(desc(c.rnn), args...)
-  c.reserve = rs
-  return result
-end
+import Flux.Tracker
+import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
 
 istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
 
 function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h) :
+    track(m, x, h, m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h)
   return result[2], result[1]
 end
 
 function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h) :
+    track(m, x, h, m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h)
   return result[2], result[1]
 end
 
 function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
   result = istrain(m, h, x) ?
-    track(RNNCall(m), x, h[1], h[2]) :
+    track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
     forward(desc(m), x, h[1], h[2])
   return (result[2], result[3]), result[1]
 end
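
A rough sketch of the dispatch above, in plain Julia rather than Flux (Tracked, fast_forward and tracked_forward below are stand-ins for Flux.Tracker's TrackedArray, the raw CUDNN forward and the recorded track(...) path, not the real names): when any weight or input is tracked the call goes through the differentiable path, otherwise it hits the untracked forward directly. The result[2], result[1] reordering is there because the CUDNN-style forward returns (y, h_new) while the cell interface returns (h_new, y).

struct Tracked{T}
  data::T
end

untrack(x) = x isa Tracked ? x.data : x
istrain(args...) = any(x -> x isa Tracked, args)

fast_forward(h, x) = (tanh.(x .+ h), tanh.(x .+ h))           # pretend CUDNN call: returns (y, h_new)
tracked_forward(h, x) = fast_forward(untrack(h), untrack(x))  # pretend recorded path

function cell(h, x)
  result = istrain(h, x) ? tracked_forward(h, x) : fast_forward(h, x)
  return result[2], result[1]                                 # (h_new, y)
end

cell(zeros(3), rand(3))           # inference path
cell(Tracked(zeros(3)), rand(3))  # "training" path
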
@@ -329,44 +316,29 @@ end
 (m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
 
-function accum_transpose!(dst::CuArray, src::CuArray)
-  function kernel(dst, src)
-    I = @cuindex dst
-    dst[I...] += src[reverse(I)...]
-    return
+@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
+  reserve, result = forwardTrain(desc(m), data(x), data(h))
+  result, function (Δ)
+    y, ho = result
+    dy, dho = Δ
+    h_ = hBatch(x, data(h))
+    dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
   end
-  blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
-  return dst
 end
 
-# function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
-#   y, ho = y_
-#   dy, dho = Δ
-#   h_ = hBatch(x, data(h))
-#   dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
-#   @back(x, dx)
-#   @back(h, unbroadcast(h, dh))
-#   (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
-#   # We don't have to make this assumption, it's just slightly more complex.
-#   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
-#   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
-#   istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
-#   istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
-# end
-
-# function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
-#   y, ho, co = y_
-#   dy, dho, dco = Δ
-#   h_ = hBatch(x, data(h))
-#   c_ = hBatch(x, data(c))
-#   dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
-#   @back(x, dx)
-#   @back(h, unbroadcast(h, dh))
-#   @back(c, unbroadcast(h, dc))
-#   (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
-#   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
-#   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
-#   istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
-#   istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
-# end
+@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
+  reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
+  result, function (Δ)
+    y, ho = result
+    dy, dho, dco = Δ
+    h_ = hBatch(x, data(h))
+    c_ = hBatch(x, data(c))
+    dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
+    (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
+    nobacksies(:RNN,
+      (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
+       dWi.', dWh.', db))
+  end
+end
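
The two @grad rules above follow Tracker's forward/pullback convention: the rule returns the primal result together with a closure that maps output sensitivities Δ back to input gradients, closing over whatever the backward pass needs (here the CUDNN reserve buffer plays that role). A minimal sketch of that shape in plain Julia, with no Flux involved and dense_forward as an illustrative name only:

function dense_forward(W, x)
  y = W * x
  y, function (Δ)     # pullback closes over W and x
    (Δ * x', W' * Δ)  # (dW, dx)
  end
end

y, back = dense_forward(randn(2, 3), randn(3))
dW, dx = back(ones(2))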


@@ -55,7 +55,8 @@ macro grad(ex)
   @capture(shortdef(ex), (name_(args__) = body_) |
                          (name_(args__) where {T__} = body_)) || error("Need a function definition")
   T == nothing && (T = [])
-  insert!(args, 1+isexpr(args[1], :parameters) , :(::typeof($name)))
+  isexpr(name, :(::)) || (name = :(::typeof($name)))
+  insert!(args, 1+isexpr(args[1], :parameters) , name)
   @q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
 end
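
The point of this change: @grad used to rewrite the rule's name to ::typeof(name) unconditionally, so gradient rules could only be attached to plain functions; keeping an explicit (m::SomeType) annotation lets a rule dispatch on a callable struct, which the CuRNN/CuGRU/CuLSTM rules above rely on. A minimal sketch of that dispatch idea in plain Julia (_forward and Scale here are illustrative stand-ins, not Tracker's internals):

struct Scale
  w::Float64
end

f(x) = 2x
(m::Scale)(x) = m.w * x

# What a rule for a plain function roughly expands to: dispatch on the function's type.
_forward(::typeof(f), x) = (f(x), Δ -> (2Δ,))

# What a rule for a callable struct can now expand to: dispatch on the struct itself.
_forward(m::Scale, x) = (m(x), Δ -> (m.w * Δ,))   # gradient w.r.t. x only, for brevity

_forward(f, 3.0)           # (6.0, pullback)
_forward(Scale(4.0), 3.0)  # (12.0, pullback)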


@@ -66,7 +66,7 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
 Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
 
-@grad function getindex(xs, i...)
+@grad function getindex(xs::AbstractArray, i...)
   data(xs)[i...], function (Δ)
     Δ′ = zero(xs)
     Δ′[i...] = data(Δ)
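
Restricting the rule to xs::AbstractArray presumably keeps it from claiming non-array tracked values, such as the tuples handled in the next file. A small sketch of what this adjoint does, in plain Julia with Tracker's data/track plumbing elided (getindex_forward is an illustrative name):

function getindex_forward(xs::AbstractArray, i...)
  xs[i...], function (Δ)
    Δ′ = zero(xs)                  # gradient has the shape of xs
    Δ′[i...] = Δ                   # scatter the output sensitivity back in
    (Δ′, map(_ -> nothing, i)...)  # no gradient w.r.t. the indices
  end
end

y, back = getindex_forward([1.0, 2.0, 3.0], 2)
back(1.0)   # ([0.0, 1.0, 0.0], nothing)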


@@ -79,13 +79,14 @@ struct TrackedTuple{T<:Tuple}
   tracker::Tracked{T}
 end
 
 data(xs::TrackedTuple) = xs.data
 tracker(xs::TrackedTuple) = xs.tracker
 
 accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
 init_grad(x::Tuple) = init_grad.(x)
 zero_grad!(x::Tuple) = zero_grad!.(x)
 
-track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f))
+track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
 
 function Base.show(io::IO, xs::TrackedTuple)
   show(io, data(xs))
@@ -96,8 +97,9 @@ Base.length(x::TrackedTuple) = length(data(x))
 Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
 
-back(::typeof(getindex), Δ, t, i) =
-  back(t, ntuple(j -> i == j ? Δ : 0, length(t)))
+@grad function getindex(xs::TrackedTuple, i)
+  data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
+end
 
 # Array collection
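
A small sketch of the tuple bookkeeping above, in plain Julia with Tracker elided (these mirror but are not the real Tracker definitions): a tuple's gradient is just a tuple of per-element gradients, seeded with zero.(xs), and indexing a tuple routes Δ back into slot i only, which is how the tracked (y, h[, c]) tuples returned by the CUDNN calls get taken apart and backpropagated.

init_grad(x) = zero(x)
init_grad(x::Tuple) = init_grad.(x)   # elementwise over the tuple

tuple_getindex_forward(xs::Tuple, i) =
  xs[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)

init_grad((1.0, [1.0, 2.0]))                         # (0.0, [0.0, 0.0])
y, back = tuple_getindex_forward((1.0, 2.0, 3.0), 2)
back(5.0)                                            # ((0, 5.0, 0), nothing)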