This commit is contained in:
Mike Innes 2019-09-19 18:33:33 +01:00
parent fc9db7ee74
commit b60df53ba1
7 changed files with 24 additions and 21 deletions

View File

@ -46,9 +46,9 @@ version = "0.6.2"
[[CUDAapi]] [[CUDAapi]]
deps = ["Libdl", "Logging"] deps = ["Libdl", "Logging"]
git-tree-sha1 = "9b2b4b71d6b7f946c9689bb4dea03ff92e3c7091" git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3" uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
version = "1.1.0" version = "1.2.0"
[[CUDAdrv]] [[CUDAdrv]]
deps = ["CUDAapi", "Libdl", "Printf"] deps = ["CUDAapi", "Libdl", "Printf"]
@ -147,9 +147,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[FFTW]] [[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "03f8776fbdae28c20c0d1d2ae4e090cd1dfcd247" git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "1.0.0" version = "1.0.1"
[[FillArrays]] [[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"] deps = ["LinearAlgebra", "Random", "SparseArrays"]
@ -170,9 +170,9 @@ version = "0.10.3"
[[GPUArrays]] [[GPUArrays]]
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"] deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "StaticArrays", "Test"]
git-tree-sha1 = "b5009ac44b141ded5e6f04c4db83807970f56e91" git-tree-sha1 = "77e27264276fe97a7e7fb928bf8999a145abc018"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "1.0.2" version = "1.0.3"
[[IRTools]] [[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"] deps = ["InteractiveUtils", "MacroTools", "Test"]
@ -388,7 +388,7 @@ version = "0.8.3"
[[Zygote]] [[Zygote]]
deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] deps = ["DiffRules", "FFTW", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "ce6d7142d665b1e4c71c678fa7db4da3bbc6743f" git-tree-sha1 = "38241b40ebd8748bcacad5e6c7ba3ab3cc7a15c9"
repo-rev = "master" repo-rev = "master"
repo-url = "https://github.com/FluxML/Zygote.jl.git" repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
@ -396,6 +396,8 @@ version = "0.3.4"
[[ZygoteRules]] [[ZygoteRules]]
deps = ["MacroTools"] deps = ["MacroTools"]
git-tree-sha1 = "def5f96ac2895fd9b48435f6b97020979ee0a4c6" git-tree-sha1 = "c4c29b30b8ff3be13d4244e78be7df2a42bc54d0"
repo-rev = "master"
repo-url = "https://github.com/FluxML/ZygoteRules.jl.git"
uuid = "700de1a5-db45-46bc-99cf-38207098b444" uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.1.0" version = "0.2.0"

View File

@ -24,6 +24,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat] [compat]
CUDAapi = "1.1" CUDAapi = "1.1"

View File

@ -6,7 +6,7 @@ using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward using MacroTools: @forward
@reexport using NNlib @reexport using NNlib
using Zygote: Params, @adjoint, gradient, forward using Zygote: Params, @adjoint, gradient, pullback
export gradient export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool, export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,

View File

@ -1,5 +1,5 @@
using Flux, CuArrays, Test using Flux, CuArrays, Test
using Flux: forward using Flux: pullback
@testset "CUDNN BatchNorm" begin @testset "CUDNN BatchNorm" begin
@testset "4D Input" begin @testset "4D Input" begin
@ -8,8 +8,8 @@ using Flux: forward
cx = gpu(x) cx = gpu(x)
cm = gpu(m) cm = gpu(m)
y, back = forward((m, x) -> m(x), m, x) y, back = pullback((m, x) -> m(x), m, x)
cy, cback = forward((m, x) -> m(x), cm, cx) cy, cback = pullback((m, x) -> m(x), cm, cx)
@test cpu(cy) y @test cpu(cy) y
@ -28,8 +28,8 @@ using Flux: forward
cx = gpu(x) cx = gpu(x)
cm = gpu(m) cm = gpu(m)
y, back = forward((m, x) -> m(x), m, x) y, back = pullback((m, x) -> m(x), m, x)
cy, cback = forward((m, x) -> m(x), cm, cx) cy, cback = pullback((m, x) -> m(x), cm, cx)
@test cpu(cy) y @test cpu(cy) y

View File

@ -1,5 +1,5 @@
using Flux, CuArrays, Test using Flux, CuArrays, Test
using Flux: forward using Flux: pullback
@testset for R in [RNN, GRU, LSTM] @testset for R in [RNN, GRU, LSTM]
m = R(10, 5) |> gpu m = R(10, 5) |> gpu
@ -22,8 +22,8 @@ end
rand(10, batch_size) rand(10, batch_size)
cux = gpu(x) cux = gpu(x)
y, back = forward((r, x) -> (r(x)), rnn, x) y, back = pullback((r, x) -> (r(x)), rnn, x)
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux) cuy, cuback = pullback((r, x) -> (r(x)), curnn, cux)
@test y collect(cuy) @test y collect(cuy)
@test haskey(Flux.CUDA.descs, curnn.cell) @test haskey(Flux.CUDA.descs, curnn.cell)

View File

@ -1,7 +1,7 @@
using Flux, Test, Statistics using Flux, Test, Statistics
using Zygote: forward using Zygote: pullback
trainmode(f, x...) = forward(f, x...)[1] trainmode(f, x...) = pullback(f, x...)[1]
trainmode(f) = (x...) -> trainmode(f, x...) trainmode(f) = (x...) -> trainmode(f, x...)
@testset "Dropout" begin @testset "Dropout" begin

View File

@ -55,7 +55,7 @@ const ϵ = 1e-7
y = rand(T, 2) y = rand(T, 2)
ŷ = rand(T, 2) ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy) for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.forward(f, , y) fwd, back = Flux.pullback(f, , y)
@test fwd isa T @test fwd isa T
@test eltype(back(one(T))[1]) == T @test eltype(back(one(T))[1]) == T
end end