Merge branch 'master' into tb/cuarrays_dnn

Mike Innes 2019-09-17 16:17:09 +01:00
commit 5baebf48f4
4 changed files with 41 additions and 17 deletions


@@ -58,9 +58,9 @@ version = "3.1.0"
 
 [[CUDAnative]]
 deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
-git-tree-sha1 = "0a00bef482b7c9127495c7f4a2a85e73b13b5af8"
+git-tree-sha1 = "52ae1ce10ebfa686e227655c47b19add89308623"
 uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
-version = "2.3.0"
+version = "2.3.1"
 
 [[CodecZlib]]
 deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
@@ -106,11 +106,11 @@ version = "4.0.0"
 
 [[CuArrays]]
 deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
-git-tree-sha1 = "8189fcb50b24998bad7518e52443fdb542403093"
+git-tree-sha1 = "155349d2c40568a23cbc4599f0e17e2fdf1bbbcc"
 repo-rev = "tb/flux"
 repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
 uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
-version = "1.2.1"
+version = "1.3.0"
 
 [[DataAPI]]
 git-tree-sha1 = "8903f0219d3472543fc4b2f5ebaf675a07f817c0"
@@ -149,9 +149,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 
 [[FFTW]]
 deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
-git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515"
+git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247"
 uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
-version = "0.3.0"
+version = "1.0.0"
 
 [[FillArrays]]
 deps = ["LinearAlgebra", "Random", "SparseArrays"]
@@ -172,9 +172,9 @@ version = "0.10.3"
 
 [[GPUArrays]]
 deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
-git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254"
+git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91"
 uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "1.0.1"
+version = "1.0.2"
 
 [[IRTools]]
 deps = ["InteractiveUtils", "MacroTools", "Test"]
@@ -200,9 +200,9 @@ version = "0.7.2"
 
 [[LLVM]]
 deps = ["CEnum", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "52cfea426bd248a427aace7d88eb5d45b84ea297"
+git-tree-sha1 = "4a05f742837779a00bd8c9a18da6817367c4245d"
 uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "1.2.0"
+version = "1.3.0"
 
 [[LibGit2]]
 uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -390,7 +390,7 @@ version = "0.8.3"
 
 [[Zygote]]
 deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e"
+git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f"
 repo-rev = "master"
 repo-url = "https://github.com/FluxML/Zygote.jl.git"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"


@@ -5,7 +5,7 @@ Consider a [simple linear regression](../models/basics.md). We create some dummy
 ```julia
 using Flux
 
-W = rand(2, 5))
+W = rand(2, 5)
 b = rand(2)
 
 predict(x) = (W * x) .+ b
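
The docs hunk above removes a stray closing parenthesis from `W = rand(2, 5))`. For context, the corrected example reads roughly as follows; the loss line is an assumed continuation reconstructed from the linked basics page, not part of this diff:

```julia
using Flux

W = rand(2, 5)
b = rand(2)

predict(x) = (W * x) .+ b

# Assumed continuation of the docs example (not in this diff):
loss(x, y) = sum((predict(x) .- y).^2)
```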


@@ -44,7 +44,8 @@ function desc(rnn)
   return d
 end
 
-using ..Flux: @adjoint
+import Zygote
+using Zygote: @adjoint
 
 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
   y, h = CUDNN.forward(desc(m), x, h)
@@ -72,15 +73,29 @@ unbroadcast(x::AbstractArray, Δ) =
   length(x) == length(Δ) ? trim(x, Δ) :
     trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
 
+coerce_cuda(x::Union{CuArray,Nothing}) = x
+coerce_cuda(x::Tuple) = coerce_cuda.(x)
+
+coerce_cuda(x) = x .+ CuArrays.fill(0)
+
+function struct_grad!(cx::Zygote.Context, x, x̄)
+  for f in fieldnames(typeof(x))
+    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
+  end
+  dx = Zygote.grad_mut(cx, x)
+  dx[] = Zygote.accum(dx[], x̄)
+  return dx
+end
+
 for RNN in (CuRNN, CuGRU)
   @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
     reserve, (y, ho) = CUDNN.forwardTrain(desc(m), x, h)
     (ho, y), function (Δ)
-      dho, dy = Δ
+      dho, dy = coerce_cuda(Δ)
       h_ = CUDNN.hBatch(x, h)
       dx, dh = CUDNN.backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-      dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
+      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing))
       (dm, unbroadcast(h, dh), dx)
     end
   end
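
The `coerce_cuda` helper added above guards against Zygote handing the pullback a gradient that is not a `CuArray` (for example a plain CPU `Array`): broadcasting any such value against the zero-dimensional `CuArrays.fill(0)` promotes the result to a `CuArray` before it reaches the CUDNN backward calls. A rough standalone illustration of that trick, assuming a working CuArrays setup (not part of the diff):

```julia
using CuArrays

# A gradient that ended up on the CPU:
cpu_grad = ones(Float32, 5)

# Broadcasting against a 0-dimensional CuArray moves the result to the GPU;
# this is the same promotion coerce_cuda(x) = x .+ CuArrays.fill(0) relies on.
gpu_grad = cpu_grad .+ CuArrays.fill(0)
@assert gpu_grad isa CuArray
```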
@@ -89,13 +104,13 @@ end
 @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
   reserve, (y, ho, co) = CUDNN.forwardTrain(desc(m), x, h, c)
   ((ho, co), y), function (Δ)
-    dhc, dy = Δ
+    dhc, dy = coerce_cuda(Δ)
     dho, dco = dhc === nothing ? (nothing, nothing) : dhc
     h_ = CUDNN.hBatch(x, h)
     c_ = CUDNN.hBatch(x, c)
     dx, dh, dc = CUDNN.backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
     (dWi, dWh), db = CUDNN.backwardWeights(descs[m], x, h_, y, reserve)
-    dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
+    dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing))
     (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx)
   end
 end
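
The move from `Ref{Any}(...)` to `struct_grad!(__context__, m, ...)` is what makes implicit parameters work: `Zygote.accum_param` registers each field gradient in the AD context, so `gradient(() -> ..., params(m))` can look them up, which the new test below verifies. The `m̄[]` indexing in that test reflects Zygote's convention of returning gradients of mutable structs as a `Ref` around a NamedTuple of field gradients; a minimal CPU-only sketch of that convention (assumed behavior of the Zygote of this era):

```julia
using Zygote

mutable struct Affine
  w::Float64
end

a = Affine(3.0)
(ā,) = gradient(a -> a.w^2, a)
ā[]  # == (w = 6.0,): the field gradients, Ref-wrapped because Affine is mutable
```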


@@ -1,6 +1,15 @@
 using Flux, CuArrays, Test
 using Flux: forward
 
+@testset for R in [RNN, GRU, LSTM]
+  m = R(10, 5) |> gpu
+  x = gpu(rand(10))
+  (m̄,) = gradient(m -> sum(m(x)), m)
+  Flux.reset!(m)
+  θ = gradient(() -> sum(m(x)), params(m))
+  @test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi])
+end
+
 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
     rnn = R(10, 5)