test cleanups
This commit is contained in:
parent
9590aa63e3
commit
2f7ad895aa
@ -3,10 +3,10 @@ module Flux
|
|||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
using Base: tail
|
using Base: tail
|
||||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
using Zygote, MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
@reexport using NNlib
|
@reexport using NNlib
|
||||||
using Zygote: Params, @adjoint, gradient
|
using Zygote: Params, @adjoint, gradient, forward
|
||||||
export gradient
|
export gradient
|
||||||
|
|
||||||
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
|
||||||
|
@ -265,7 +265,7 @@ function desc(rnn)
|
|||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
|
|
||||||
using Zygote: @adjoint
|
using ..Flux: @adjoint
|
||||||
|
|
||||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||||
result = forward(desc(m), x, h)
|
result = forward(desc(m), x, h)
|
||||||
@ -295,7 +295,7 @@ for RNN in (CuRNN, CuGRU)
|
|||||||
h_ = hBatch(x, h)
|
h_ = hBatch(x, h)
|
||||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
(dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
@ -309,8 +309,7 @@ end
|
|||||||
c_ = hBatch(x, c)
|
c_ = hBatch(x, c)
|
||||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||||
nobacksies(:RNN,
|
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
transpose(dWi), transpose(dWh), db)
|
||||||
transpose(dWi), transpose(dWh), db))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
using Flux: gpu
|
using Flux: gpu
|
||||||
using Zygote
|
|
||||||
|
|
||||||
@info "Testing GPU Support"
|
@info "Testing GPU Support"
|
||||||
|
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
using Flux, CuArrays, Test
|
using Flux, CuArrays, Test
|
||||||
using Zygote
|
|
||||||
trainmode(f, x...) = forward(f, x...)[1]
|
trainmode(f, x...) = forward(f, x...)[1]
|
||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
using Test
|
using Test
|
||||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||||
σ, binarycrossentropy, logitbinarycrossentropy
|
σ, binarycrossentropy, logitbinarycrossentropy
|
||||||
using Zygote
|
|
||||||
|
|
||||||
const ϵ = 1e-7
|
const ϵ = 1e-7
|
||||||
|
|
||||||
@ -56,7 +55,7 @@ const ϵ = 1e-7
|
|||||||
y = rand(T, 2)
|
y = rand(T, 2)
|
||||||
ŷ = rand(T, 2)
|
ŷ = rand(T, 2)
|
||||||
for f in (mse, crossentropy, logitcrossentropy)
|
for f in (mse, crossentropy, logitcrossentropy)
|
||||||
fwd, back = Zygote.forward(f, ŷ, y)
|
fwd, back = Flux.forward(f, ŷ, y)
|
||||||
@test fwd isa T
|
@test fwd isa T
|
||||||
@test eltype(back(one(T))[1]) == T
|
@test eltype(back(one(T))[1]) == T
|
||||||
end
|
end
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
using Flux.Optimise
|
using Flux.Optimise
|
||||||
using Flux.Optimise: runall
|
using Flux.Optimise: runall
|
||||||
using Zygote
|
using Flux: Params, gradient
|
||||||
using Zygote: Params, gradient
|
|
||||||
using Test
|
using Test
|
||||||
Zygote.@nograd sleep
|
|
||||||
|
# TODO move this to Zygote
|
||||||
|
Flux.Zygote.@nograd sleep
|
||||||
|
|
||||||
@testset "Optimise" begin
|
@testset "Optimise" begin
|
||||||
w = randn(10, 10)
|
w = randn(10, 10)
|
||||||
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||||
|
Loading…
Reference in New Issue
Block a user