test cleanups

Author: Mike Innes
Date: 2019-08-19 15:22:50 +01:00
Parent: 9590aa63e3
Commit: 2f7ad895aa

6 changed files with 12 additions and 14 deletions

Changed file 1 of 6

@@ -3,10 +3,10 @@ module Flux
 # Zero Flux Given
 using Base: tail
-using MacroTools, Juno, Requires, Reexport, Statistics, Random
+using Zygote, MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 @reexport using NNlib
-using Zygote: Params, @adjoint, gradient
+using Zygote: Params, @adjoint, gradient, forward
 export gradient
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
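The newly imported `forward` is Zygote's two-stage differentiation entry point: it runs the function and returns the primal value together with a pullback closure that maps an output sensitivity back to gradients of the inputs. A minimal sketch (the toy function below is illustrative, not part of the commit):

    using Flux  # after this commit, `forward` is reachable as Flux.forward

    y, back = Flux.forward(x -> 3x^2, 2)  # y == 12
    dx, = back(1)                         # derivative of 3x^2 at x = 2, i.e. 12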

Changed file 2 of 6

@@ -265,7 +265,7 @@ function desc(rnn)
   return d
 end
-using Zygote: @adjoint
+using ..Flux: @adjoint
 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
   result = forward(desc(m), x, h)
@@ -295,7 +295,7 @@ for RNN in (CuRNN, CuGRU)
       h_ = hBatch(x, h)
       dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
+      (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
     end
   end
 end
@@ -309,8 +309,7 @@ end
     c_ = hBatch(x, c)
     dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(h, dh), unbroadcast(c, dc),
-                      transpose(dWi), transpose(dWh), db))
+    (dx, unbroadcast(h, dh), unbroadcast(c, dc), transpose(dWi), transpose(dWh), db)
   end
 end
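The deleted `nobacksies` call appears to be a holdover from the Tracker-based AD, where it marked a gradient tuple as non-differentiable; under Zygote a custom adjoint's pullback simply returns a plain tuple with one gradient per argument, which is what the cleaned-up code now does. A rough sketch of that pattern with a hypothetical `myscale` function (not the CUDNN adjoint itself):

    using Flux: @adjoint  # same import style as the curnn.jl change above

    myscale(w, x) = w .* x

    # Primal result plus a pullback returning a plain (dw, dx) tuple.
    @adjoint myscale(w, x) = myscale(w, x), Δ -> (Δ .* x, Δ .* w)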

Changed file 3 of 6

@@ -1,6 +1,5 @@
 using Flux, CuArrays, Test
 using Flux: gpu
-using Zygote
 @info "Testing GPU Support"

Changed file 4 of 6

@@ -1,5 +1,4 @@
 using Flux, CuArrays, Test
-using Zygote
 trainmode(f, x...) = forward(f, x...)[1]
 @testset "CUDNN BatchNorm" begin

Changed file 5 of 6

@@ -1,7 +1,6 @@
 using Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy
-using Zygote
 const ϵ = 1e-7
@@ -56,7 +55,7 @@ const ϵ = 1e-7
     y = rand(T, 2)
     ŷ = rand(T, 2)
     for f in (mse, crossentropy, logitcrossentropy)
-      fwd, back = Zygote.forward(f, ŷ, y)
+      fwd, back = Flux.forward(f, ŷ, y)
       @test fwd isa T
       @test eltype(back(one(T))[1]) == T
     end

Changed file 6 of 6

@@ -1,9 +1,11 @@
 using Flux.Optimise
 using Flux.Optimise: runall
-using Zygote
-using Zygote: Params, gradient
+using Flux: Params, gradient
 using Test
-Zygote.@nograd sleep
+# TODO move this to Zygote
+Flux.Zygote.@nograd sleep
 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
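`@nograd sleep` registers a trivial adjoint for `sleep`, presumably because the training-loop tests call it inside the loss being differentiated and Zygote cannot trace through its `ccall`. A small sketch of the effect (the loss below is illustrative, not the test's):

    using Flux
    Flux.Zygote.@nograd sleep  # as in the diff above

    w = randn(2, 2)
    loss(x) = (sleep(0.01); sum(w * x))  # side effect inside differentiated code
    gs = Flux.gradient(() -> loss(ones(2)), Flux.params(w))
    gs[w]                                # 2×2 matrix of ones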