From a600a9ceed1c1d95619cd03da4a43eea1f3c2421 Mon Sep 17 00:00:00 2001 From: Naba7 Date: Sat, 14 Sep 2019 10:56:17 +0530 Subject: [PATCH 1/5] removed extra parenthesis --- docs/src/training/optimisers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/training/optimisers.md b/docs/src/training/optimisers.md index 4a8d09cb..9eb659c4 100644 --- a/docs/src/training/optimisers.md +++ b/docs/src/training/optimisers.md @@ -5,7 +5,7 @@ Consider a [simple linear regression](../models/basics.md). We create some dummy ```julia using Flux -W = rand(2, 5)) +W = rand(2, 5) b = rand(2) predict(x) = (W * x) .+ b From fe57215b7e7e2be3b3543707201d29d08e1ad970 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 17 Sep 2019 15:21:03 +0100 Subject: [PATCH 2/5] test fillarray gradients --- src/cuda/curnn.jl | 8 ++++++-- test/cuda/curnn.jl | 6 ++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index ca8b5140..b989c771 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -297,11 +297,15 @@ unbroadcast(x::AbstractArray, Δ) = length(x) == length(Δ) ? trim(x, Δ) : trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) +coerce_cuda(x::Union{CuArray,Nothing}) = x + +coerce_cuda(x) = x .+ CuArrays.fill(0) + for RNN in (CuRNN, CuGRU) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} reserve, (y, ho) = forwardTrain(desc(m), x, h) (ho, y), function (Δ) - dho, dy = Δ + dho, dy = coerce_cuda.(Δ) h_ = hBatch(x, h) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) @@ -314,7 +318,7 @@ end @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c) ((ho, co), y), function (Δ) - dhc, dy = Δ + dhc, dy = coerce_cuda.(Δ) dho, dco = dhc === nothing ? (nothing, nothing) : dhc h_ = hBatch(x, h) c_ = hBatch(x, c) diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl index c1bc804e..49042514 100644 --- a/test/cuda/curnn.jl +++ b/test/cuda/curnn.jl @@ -1,6 +1,12 @@ using Flux, CuArrays, Test using Flux: forward +@testset for R in [RNN, GRU, LSTM] + m = R(10, 5) |> gpu + x = gpu(rand(10)) + @test gradient(m -> sum(m(x)), m) isa Tuple +end + @testset "RNN" begin @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5) rnn = R(10, 5) From b348b204529c54a988bac87d7a0ee5fd6f8cdbb8 Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 17 Sep 2019 15:41:42 +0100 Subject: [PATCH 3/5] cudnn rnns + implicit gradients --- src/cuda/curnn.jl | 16 +++++++++++++--- test/cuda/curnn.jl | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index b989c771..4e2a773b 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -269,7 +269,8 @@ function desc(rnn) return d end -using ..Flux: @adjoint +import Zygote +using Zygote: @adjoint function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} y, h′ = forward(desc(m), x, h) @@ -301,6 +302,15 @@ coerce_cuda(x::Union{CuArray,Nothing}) = x coerce_cuda(x) = x .+ CuArrays.fill(0) +function struct_grad!(cx::Zygote.Context, x, x̄) + for f in fieldnames(typeof(x)) + Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f)) + end + dx = Zygote.grad_mut(cx, x) + dx[] = Zygote.accum(dx[], x̄) + return dx +end + for RNN in (CuRNN, CuGRU) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} reserve, (y, ho) = forwardTrain(desc(m), x, h) @@ -309,7 +319,7 @@ for RNN in (CuRNN, CuGRU) h_ = hBatch(x, h) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) - dm = Ref{Any}((σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing)) + dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing)) (dm, unbroadcast(h, dh), dx) end end @@ -324,7 +334,7 @@ end c_ = hBatch(x, c) dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve) (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) - dm = Ref{Any}((Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing)) + dm = struct_grad!(__context__, m, (Wi=transpose(dWi),Wh=transpose(dWh),b=db,h=nothing,c=nothing)) (dm, (unbroadcast(h, dh), unbroadcast(c, dc)), dx) end end diff --git a/test/cuda/curnn.jl b/test/cuda/curnn.jl index 49042514..1e834d14 100644 --- a/test/cuda/curnn.jl +++ b/test/cuda/curnn.jl @@ -4,7 +4,10 @@ using Flux: forward @testset for R in [RNN, GRU, LSTM] m = R(10, 5) |> gpu x = gpu(rand(10)) - @test gradient(m -> sum(m(x)), m) isa Tuple + (m̄,) = gradient(m -> sum(m(x)), m) + Flux.reset!(m) + θ = gradient(() -> sum(m(x)), params(m)) + @test collect(m̄[].cell[].Wi) == collect(θ[m.cell.Wi]) end @testset "RNN" begin From 368b1f53b408cd4f1e76576c338de03e96adb53a Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 17 Sep 2019 15:49:39 +0100 Subject: [PATCH 4/5] tuple support --- src/cuda/curnn.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/cuda/curnn.jl b/src/cuda/curnn.jl index 4e2a773b..616f8327 100644 --- a/src/cuda/curnn.jl +++ b/src/cuda/curnn.jl @@ -299,6 +299,7 @@ unbroadcast(x::AbstractArray, Δ) = trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) coerce_cuda(x::Union{CuArray,Nothing}) = x +coerce_cuda(x::Tuple) = coerce_cuda.(x) coerce_cuda(x) = x .+ CuArrays.fill(0) @@ -315,7 +316,7 @@ for RNN in (CuRNN, CuGRU) @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64} reserve, (y, ho) = forwardTrain(desc(m), x, h) (ho, y), function (Δ) - dho, dy = coerce_cuda.(Δ) + dho, dy = coerce_cuda(Δ) h_ = hBatch(x, h) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve) (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve) @@ -328,7 +329,7 @@ end @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64} reserve, (y, ho, co) = forwardTrain(desc(m), x, h, c) ((ho, co), y), function (Δ) - dhc, dy = coerce_cuda.(Δ) + dhc, dy = coerce_cuda(Δ) dho, dco = dhc === nothing ? (nothing, nothing) : dhc h_ = hBatch(x, h) c_ = hBatch(x, c) From fc9db7ee74980d0e50a72590ca9c1804c201a31c Mon Sep 17 00:00:00 2001 From: Mike Innes Date: Tue, 17 Sep 2019 15:49:48 +0100 Subject: [PATCH 5/5] pkg up --- Manifest.toml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 17eb544e..2d1af7e8 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -58,9 +58,9 @@ version = "3.1.0" [[CUDAnative]] deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"] -git-tree-sha1 = "0a00bef482b7c9127495c7f4a2a85e73b13b5af8" +git-tree-sha1 = "52ae1ce10ebfa686e227655c47b19add89308623" uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" -version = "2.3.0" +version = "2.3.1" [[CodecZlib]] deps = ["BinaryProvider", "Libdl", "TranscodingStreams"] @@ -147,9 +147,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FFTW]] deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] -git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515" +git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "0.3.0" +version = "1.0.0" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] @@ -170,9 +170,9 @@ version = "0.10.3" [[GPUArrays]] deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"] -git-tree-sha1 = "dd169c636d1d3656a9faca772f5bd7c226a61254" +git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "1.0.1" +version = "1.0.2" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] @@ -198,9 +198,9 @@ version = "0.7.2" [[LLVM]] deps = ["CEnum", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "52cfea426bd248a427aace7d88eb5d45b84ea297" +git-tree-sha1 = "4a05f742837779a00bd8c9a18da6817367c4245d" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "1.2.0" +version = "1.3.0" [[LibGit2]] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" @@ -264,7 +264,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "0.3.7" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] @@ -388,7 +388,7 @@ version = "0.8.3" [[Zygote]] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "9186cb0b3b59219e4aba0840614d6a9d7282012e" +git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f" repo-rev = "master" repo-url = "https://github.com/FluxML/Zygote.jl.git" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"