test cleanups

parent 9590aa63e3
commit 2f7ad895aa
@@ -3,10 +3,10 @@ module Flux
 # Zero Flux Given
 
 using Base: tail
-using MacroTools, Juno, Requires, Reexport, Statistics, Random
+using Zygote, MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 @reexport using NNlib
-using Zygote: Params, @adjoint, gradient
+using Zygote: Params, @adjoint, gradient, forward
 export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,

@@ -265,7 +265,7 @@ function desc(rnn)
   return d
 end
 
-using Zygote: @adjoint
+using ..Flux: @adjoint
 
 function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
   result = forward(desc(m), x, h)

@@ -295,7 +295,7 @@ for RNN in (CuRNN, CuGRU)
       h_ = hBatch(x, h)
       dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
+      (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
     end
   end
 end

@@ -309,8 +309,7 @@ end
       c_ = hBatch(x, c)
       dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
       (dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
-      nobacksies(:RNN,
-        (dx, unbroadcast(h, dh), unbroadcast(c, dc),
-         transpose(dWi), transpose(dWh), db))
+      (dx, unbroadcast(h, dh), unbroadcast(c, dc),
+       transpose(dWi), transpose(dWh), db)
     end
   end

@@ -1,6 +1,5 @@
 using Flux, CuArrays, Test
 using Flux: gpu
-using Zygote
 
 @info "Testing GPU Support"
 

@@ -1,5 +1,4 @@
 using Flux, CuArrays, Test
-using Zygote
 trainmode(f, x...) = forward(f, x...)[1]
 
 @testset "CUDNN BatchNorm" begin

@@ -1,7 +1,6 @@
 using Test
 using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
   σ, binarycrossentropy, logitbinarycrossentropy
-using Zygote
 
 const ϵ = 1e-7
 

@@ -56,7 +55,7 @@ const ϵ = 1e-7
     y = rand(T, 2)
     ŷ = rand(T, 2)
     for f in (mse, crossentropy, logitcrossentropy)
-      fwd, back = Zygote.forward(f, ŷ, y)
+      fwd, back = Flux.forward(f, ŷ, y)
       @test fwd isa T
       @test eltype(back(one(T))[1]) == T
     end

@@ -1,9 +1,11 @@
 using Flux.Optimise
 using Flux.Optimise: runall
-using Zygote
-using Zygote: Params, gradient
+using Flux: Params, gradient
 using Test
-Zygote.@nograd sleep
+
+# TODO move this to Zygote
+Flux.Zygote.@nograd sleep
+
 @testset "Optimise" begin
   w = randn(10, 10)
   @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),