Compare commits
11 Commits
master
...
zygote_old
Author | SHA1 | Date |
---|---|---|
![]() |
4c18f84ad3 | |
![]() |
bde51aa5a6 | |
![]() |
46e245b87d | |
![]() |
36055a9907 | |
![]() |
aa17cd77d0 | |
![]() |
66cc95b927 | |
![]() |
abf7f491ed | |
![]() |
7ba176f59a | |
![]() |
5514a0f53f | |
![]() |
2f256b393a | |
![]() |
e3f05eeaf3 |
|
@ -6,7 +6,7 @@ os:
|
|||
# - osx
|
||||
|
||||
julia:
|
||||
- 1.0
|
||||
- 1.1
|
||||
- nightly
|
||||
|
||||
matrix:
|
||||
|
|
|
@ -99,15 +99,21 @@ git-tree-sha1 = "4c4d727f1b7e0092134fabfab6396b8945c1ea5b"
|
|||
uuid = "f6369f11-7733-5829-9624-2563aa707210"
|
||||
version = "0.10.3"
|
||||
|
||||
[[IRTools]]
|
||||
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||
git-tree-sha1 = "a5a47cba5f8d9a56ff683789cdd6d20ce1cb9d53"
|
||||
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||
version = "0.1.2"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["Markdown"]
|
||||
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile", "Test"]
|
||||
git-tree-sha1 = "ce6246e19061e36cbdce954caaae717498daeed8"
|
||||
git-tree-sha1 = "dc568a3dbc4d0505d252d104bed03710a9a39441"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.5.4"
|
||||
version = "0.5.5"
|
||||
|
||||
[[LibGit2]]
|
||||
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
|
||||
|
@ -248,12 +254,6 @@ version = "0.29.0"
|
|||
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
|
||||
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[Tracker]]
|
||||
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
|
||||
git-tree-sha1 = "4eeea9f0ef9b8c7d1c5c5b1f8f68cb9b7f45d7df"
|
||||
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
version = "0.1.0"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Pkg", "Random", "Test"]
|
||||
git-tree-sha1 = "90f845c65c50bc57d6ffc815dbab2a4003ccf75c"
|
||||
|
@ -278,3 +278,11 @@ deps = ["BinaryProvider", "Libdl", "Printf", "Test"]
|
|||
git-tree-sha1 = "4000c633efe994b2e10b31b6d91382c4b7412dac"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.8.0"
|
||||
|
||||
[[Zygote]]
|
||||
deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions"]
|
||||
git-tree-sha1 = "7e99e2a6c5287fe658273fdd1723726ff8a211d9"
|
||||
repo-rev = "master"
|
||||
repo-url = "https://github.com/FluxML/Zygote.jl.git"
|
||||
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
version = "0.1.0+"
|
||||
|
|
|
@ -19,5 +19,5 @@ SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
|
|||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
|
||||
ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
|
|
|
@ -5,17 +5,13 @@ module Flux
|
|||
using Base: tail
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
@reexport using NNlib
|
||||
using Zygote: Params, @adjoint, gradient
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
|
||||
DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
|
||||
params, mapleaves, cpu, gpu, f32, f64
|
||||
|
||||
@reexport using NNlib
|
||||
|
||||
using Tracker
|
||||
using Tracker: data
|
||||
export Tracker, TrackedArray, TrackedVector, TrackedMatrix, param
|
||||
|
||||
include("optimise/Optimise.jl")
|
||||
using .Optimise
|
||||
using .Optimise: @epochs
|
||||
|
|
|
@ -196,33 +196,5 @@ end
|
|||
(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
|
||||
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
|
||||
running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
|
||||
track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
|
||||
|
||||
@grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
|
||||
batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
|
||||
|
|
|
@ -221,7 +221,6 @@ end
|
|||
# Interface
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using .CuArrays.CUDAnative
|
||||
using .CuArrays: @cuindex, cudims
|
||||
|
||||
|
@ -236,10 +235,9 @@ function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
|||
return dst
|
||||
end
|
||||
|
||||
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuParam{T,2},<:CuParam{T,1}}
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::RNNDesc)
|
||||
|
@ -267,57 +265,48 @@ function desc(rnn)
|
|||
return d
|
||||
end
|
||||
|
||||
import Flux.Tracker
|
||||
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
|
||||
using Zygote: @adjoint
|
||||
|
||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||
|
||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
result = forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h, m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h)
|
||||
function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
result = forward(desc(m), x, h)
|
||||
return result[2], result[1]
|
||||
end
|
||||
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||
result = istrain(m, h, x) ?
|
||||
track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
|
||||
forward(desc(m), x, h[1], h[2])
|
||||
function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
|
||||
result = forward(desc(m), x, h[1], h[2])
|
||||
return (result[2], result[3]), result[1]
|
||||
end
|
||||
|
||||
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||
|
||||
@grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data(x), data(h))
|
||||
@adjoint function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), x, h)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
h_ = hBatch(x, h)
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(h, dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
||||
@grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||
@adjoint function (m::CuLSTM)(x, h, c, Wi, Wh, b)
|
||||
reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
|
||||
result, function (Δ)
|
||||
y, ho = result
|
||||
dy, dho, dco = Δ
|
||||
h_ = hBatch(x, data(h))
|
||||
c_ = hBatch(x, data(c))
|
||||
h_ = hBatch(x, h)
|
||||
c_ = hBatch(x, c)
|
||||
dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], x, h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(h, dh), unbroadcast(c, dc),
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
|
|
|
@ -72,7 +72,7 @@ Dense(W, b) = Dense(W, b, identity)
|
|||
|
||||
function Dense(in::Integer, out::Integer, σ = identity;
|
||||
initW = glorot_uniform, initb = zeros)
|
||||
return Dense(param(initW(out, in)), param(initb(out)), σ)
|
||||
return Dense(initW(out, in), initb(out), σ)
|
||||
end
|
||||
|
||||
@treelike Dense
|
||||
|
@ -104,7 +104,7 @@ struct Diagonal{T}
|
|||
end
|
||||
|
||||
Diagonal(in::Integer; initα = ones, initβ = zeros) =
|
||||
Diagonal(param(initα(in)), param(initβ(in)))
|
||||
Diagonal(initα(in), initβ(in))
|
||||
|
||||
@treelike Diagonal
|
||||
|
||||
|
|
|
@ -41,7 +41,7 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
Conv(init(k..., ch...), zeros(ch[2]), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike Conv
|
||||
|
@ -91,7 +91,7 @@ ConvTranspose(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
|
||||
ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1) where N =
|
||||
ConvTranspose(param(init(k..., reverse(ch)...)), param(zeros(ch[2])), σ,
|
||||
ConvTranspose(init(k..., reverse(ch)...), zeros(ch[2]), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike ConvTranspose
|
||||
|
@ -142,13 +142,13 @@ DepthwiseConv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Integer, σ = identity; init = glorot_uniform,
|
||||
stride = 1, pad = 0) where N =
|
||||
DepthwiseConv(param(init(k..., 1, ch)), param(zeros(ch)), σ,
|
||||
DepthwiseConv(init(k..., 1, ch), zeros(ch), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
DepthwiseConv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = glorot_uniform,
|
||||
stride::NTuple{N,Integer} = map(_->1,k),
|
||||
pad::NTuple{N,Integer} = map(_->0,k)) where N =
|
||||
DepthwiseConv(param(init(k..., ch[2], ch[1])), param(zeros(ch[2]*ch[1])), σ,
|
||||
DepthwiseConv(init(k..., ch[2], ch[1]), zeros(ch[2] * ch[1]), σ,
|
||||
stride = stride, pad = pad)
|
||||
|
||||
@treelike DepthwiseConv
|
||||
|
|
|
@ -1,16 +1,6 @@
|
|||
"""
|
||||
testmode!(m)
|
||||
testmode!(m, false)
|
||||
istraining() = false
|
||||
|
||||
Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
|
||||
(or back to training mode with `false`).
|
||||
"""
|
||||
function testmode!(m, val::Bool=true)
|
||||
prefor(x -> _testmode!(x, val), m)
|
||||
return m
|
||||
end
|
||||
|
||||
_testmode!(m, test) = nothing
|
||||
@adjoint istraining() = true, _ -> nothing
|
||||
|
||||
"""
|
||||
Dropout(p)
|
||||
|
@ -23,44 +13,38 @@ Does nothing to the input once in [`testmode!`](@ref).
|
|||
"""
|
||||
mutable struct Dropout{F}
|
||||
p::F
|
||||
active::Bool
|
||||
end
|
||||
|
||||
function Dropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
Dropout{typeof(p)}(p, true)
|
||||
function Dropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
new{typeof(p)}(p)
|
||||
end
|
||||
end
|
||||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
function (a::Dropout)(x)
|
||||
a.active || return x
|
||||
istraining() || return x
|
||||
y = similar(x)
|
||||
rand!(y)
|
||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
return x .* y
|
||||
end
|
||||
|
||||
_testmode!(a::Dropout, test) = (a.active = !test)
|
||||
|
||||
"""
|
||||
AlphaDropout(p)
|
||||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
||||
A dropout layer. It is used in Self-Normalizing Neural Networks.
|
||||
(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
|
||||
The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
|
||||
"""
|
||||
mutable struct AlphaDropout{F}
|
||||
p::F
|
||||
active::Bool
|
||||
end
|
||||
|
||||
function AlphaDropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
AlphaDropout(p,true)
|
||||
function AlphaDropout(p)
|
||||
@assert 0 ≤ p ≤ 1
|
||||
new{typeof(p)}(p)
|
||||
end
|
||||
end
|
||||
|
||||
function (a::AlphaDropout)(x)
|
||||
a.active || return x
|
||||
istraining() || return x
|
||||
λ = eltype(x)(1.0507009873554804934193349852946)
|
||||
α = eltype(x)(1.6732632423543772848170429916717)
|
||||
α1 = eltype(x)(-λ*α)
|
||||
|
@ -72,8 +56,6 @@ function (a::AlphaDropout)(x)
|
|||
return x
|
||||
end
|
||||
|
||||
_testmode!(a::AlphaDropout, test) = (a.active = !test)
|
||||
|
||||
"""
|
||||
LayerNorm(h::Integer)
|
||||
|
||||
|
@ -133,13 +115,12 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
BatchNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
|
@ -151,7 +132,7 @@ function (BN::BatchNorm)(x)
|
|||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||
γ = reshape(BN.γ, affine_shape...)
|
||||
β = reshape(BN.β, affine_shape...)
|
||||
if !BN.active
|
||||
if !istraining()
|
||||
μ = reshape(BN.μ, affine_shape...)
|
||||
σ² = reshape(BN.σ², affine_shape...)
|
||||
ϵ = BN.ϵ
|
||||
|
@ -160,11 +141,11 @@ function (BN::BatchNorm)(x)
|
|||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||
ϵ = data(convert(T, BN.ϵ))
|
||||
ϵ = convert(T, BN.ϵ)
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, BN.momentum))
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(data(σ²), :)
|
||||
mtm = convert(T, BN.momentum)
|
||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(μ, :)
|
||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(σ², :)
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
|
@ -174,12 +155,10 @@ function (BN::BatchNorm)(x)
|
|||
end
|
||||
|
||||
children(BN::BatchNorm) =
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
|
||||
|
||||
mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
|
||||
|
||||
_testmode!(BN::BatchNorm, test) = (BN.active = !test)
|
||||
BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
|
||||
|
||||
function Base.show(io::IO, l::BatchNorm)
|
||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||
|
@ -226,13 +205,12 @@ mutable struct InstanceNorm{F,V,W,N}
|
|||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
InstanceNorm(chs::Integer, λ = identity;
|
||||
initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
|
||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
InstanceNorm(λ, initβ(chs), initγ(chs),
|
||||
zeros(chs), ones(chs), ϵ, momentum)
|
||||
|
||||
function (in::InstanceNorm)(x)
|
||||
size(x, ndims(x)-1) == length(in.β) ||
|
||||
|
@ -249,22 +227,22 @@ function (in::InstanceNorm)(x)
|
|||
m = prod(size(x)[1:end-2])
|
||||
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||
|
||||
if !in.active
|
||||
if !istraining()
|
||||
μ = expand_inst(in.μ, affine_shape)
|
||||
σ² = expand_inst(in.σ², affine_shape)
|
||||
ϵ = in.ϵ
|
||||
else
|
||||
T = eltype(x)
|
||||
|
||||
ϵ = data(convert(T, in.ϵ))
|
||||
ϵ = convert(T, in.ϵ)
|
||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||
μ = mean(x, dims = axes)
|
||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||
|
||||
# update moving mean/std
|
||||
mtm = data(convert(T, in.momentum))
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||
mtm = convert(T, in.momentum)
|
||||
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(μ, (c, bs)), dims = 2), dims=2)
|
||||
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(σ², (c, bs))), dims = 2), dims=2)
|
||||
end
|
||||
|
||||
let λ = in.λ
|
||||
|
@ -274,12 +252,10 @@ function (in::InstanceNorm)(x)
|
|||
end
|
||||
|
||||
children(in::InstanceNorm) =
|
||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
|
||||
|
||||
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||
|
||||
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||
|
|
|
@ -42,21 +42,6 @@ end
|
|||
|
||||
Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
|
||||
|
||||
_truncate(x::AbstractArray) = Tracker.data(x)
|
||||
_truncate(x::Tuple) = _truncate.(x)
|
||||
|
||||
"""
|
||||
truncate!(rnn)
|
||||
|
||||
Truncates the gradient of the hidden state in recurrent layers. The value of the
|
||||
state is preserved. See also `reset!`.
|
||||
|
||||
Assuming you have a `Recur` layer `rnn`, this is roughly equivalent to
|
||||
|
||||
rnn.state = Tracker.data(rnn.state)
|
||||
"""
|
||||
truncate!(m) = prefor(x -> x isa Recur && (x.state = _truncate(x.state)), m)
|
||||
|
||||
"""
|
||||
reset!(rnn)
|
||||
|
||||
|
@ -83,8 +68,8 @@ end
|
|||
|
||||
RNNCell(in::Integer, out::Integer, σ = tanh;
|
||||
init = glorot_uniform) =
|
||||
RNNCell(σ, param(init(out, in)), param(init(out, out)),
|
||||
param(init(out)), param(zeros(out)))
|
||||
RNNCell(σ, init(out, in), init(out, out),
|
||||
init(out), zeros(out))
|
||||
|
||||
function (m::RNNCell)(h, x)
|
||||
σ, Wi, Wh, b = m.σ, m.Wi, m.Wh, m.b
|
||||
|
@ -122,8 +107,8 @@ end
|
|||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(init(out*4)),
|
||||
param(zeros(out)), param(zeros(out)))
|
||||
cell = LSTMCell(init(out * 4, in), init(out * 4, out), init(out * 4),
|
||||
zeros(out), zeros(out))
|
||||
cell.b.data[gate(out, 2)] .= 1
|
||||
return cell
|
||||
end
|
||||
|
@ -168,8 +153,8 @@ mutable struct GRUCell{A,V}
|
|||
end
|
||||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(init(out*3)), param(zeros(out)))
|
||||
GRUCell(init(out * 3, in), init(out * 3, out),
|
||||
init(out * 3), zeros(out))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
|
|
@ -49,8 +49,3 @@ function normalise(x::AbstractArray; dims=1)
|
|||
σ′ = std(x, dims = dims, mean = μ′, corrected=false)
|
||||
return (x .- μ′) ./ σ′
|
||||
end
|
||||
|
||||
function normalise(x::AbstractArray, dims)
|
||||
Base.depwarn("`normalise(x::AbstractArray, dims)` is deprecated, use `normalise(a, dims=dims)` instead.", :normalise)
|
||||
normalise(x, dims = dims)
|
||||
end
|
||||
|
|
|
@ -59,15 +59,6 @@ onecold(y::AbstractVector, labels = 1:length(y)) = labels[Base.argmax(y)]
|
|||
onecold(y::AbstractMatrix, labels...) =
|
||||
dropdims(mapslices(y -> onecold(y, labels...), y, dims=1), dims=1)
|
||||
|
||||
function argmax(xs...)
|
||||
Base.depwarn("`argmax(...) is deprecated, use `onecold(...)` instead.", :argmax)
|
||||
return onecold(xs...)
|
||||
end
|
||||
|
||||
# Ambiguity hack
|
||||
|
||||
a::TrackedMatrix * b::OneHotVector = invoke(*, Tuple{AbstractMatrix,OneHotVector}, a, b)
|
||||
a::TrackedMatrix * b::OneHotMatrix = invoke(*, Tuple{AbstractMatrix,OneHotMatrix}, a, b)
|
||||
|
||||
onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||
onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||
# TODO probably still want this as a custom adjoint Zygote
|
||||
# onecold(x::TrackedVector, l...) = onecold(data(x), l...)
|
||||
# onecold(x::TrackedMatrix, l...) = onecold(data(x), l...)
|
||||
|
|
|
@ -7,6 +7,5 @@ export train!,
|
|||
|
||||
include("optimisers.jl")
|
||||
include("train.jl")
|
||||
include("deprecations.jl")
|
||||
|
||||
end
|
||||
|
|
|
@ -1,126 +0,0 @@
|
|||
using Base: depwarn
|
||||
using Flux: Params
|
||||
|
||||
check_decay(opt, decay) = decay == 0 ? opt : Optimiser(opt, InvDecay(decay))
|
||||
|
||||
# legacy update rule
|
||||
updaterule(opt, ps) = () -> _update_params!(opt, ps)
|
||||
|
||||
function SGD(params::Union{AbstractArray, Params}, η = 0.1; decay = 0.)
|
||||
depwarn("SGD(params) is deprecated; use Descent(η::Float64) instead", :SGD)
|
||||
|
||||
ps = params
|
||||
opt = Descent(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Momentum(params::Union{AbstractArray, Params}, η = 0.01; ρ = 0.9, decay = 0.)
|
||||
depwarn("Momentum(params) is deprecated; use Momentum(η::Float64) instead", :Momentum)
|
||||
|
||||
ps = params
|
||||
opt = Momentum(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function Nesterov(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("Nesterov(params) is deprecated; use Nesterov(η::Float64) instead", :Nesterov)
|
||||
|
||||
ps = params
|
||||
opt = Nesterov(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function RMSProp(params::Union{AbstractArray, Params}, η = 0.001; ρ = 0.9, decay = 0.)
|
||||
depwarn("RMSProp(params) is deprecated; use RMSProp(η::Float64) instead", :RMSProp)
|
||||
|
||||
ps = params
|
||||
opt = RMSProp(η, ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAM(params) is deprecated; use ADAM(η::Float64) instead", :ADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAGrad(params::Union{AbstractArray, Params}, η::Float64 = 0.1; decay = 0.)
|
||||
depwarn("ADAGrad(params) is deprecated; use ADAGrad(η::Float64) instead", :ADAGrad)
|
||||
|
||||
ps = params
|
||||
opt = ADAGrad(η)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADADelta(params::Union{AbstractArray, Params}, ρ::Float64 = 0.9; decay = 0.)
|
||||
depwarn("ADADelta(params) is deprecated; use ADADelta(η::Float64) instead", :ADADelta)
|
||||
|
||||
ps = params
|
||||
opt = ADADelta(ρ)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AdaMax(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AdaMax(params) is deprecated; use AdaMax(η::Float64) instead", :AdaMax)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AdaMax(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function AMSGrad(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("AMSGrad(params) is deprecated; use AMSGrad(η::Float64) instead", :AMSGrad)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = AMSGrad(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function NADAM(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("NADAM(params) is deprecated; use NADAM(η::Float64) instead", :NADAM)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = NADAM(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
function ADAMW(params::Union{AbstractArray, Params}, η = 0.001; β1 = 0.9, β2 = 0.999, decay = 0.)
|
||||
depwarn("ADAMW(params) is deprecated; use ADAMW(η::Float64) instead", :ADAMW)
|
||||
|
||||
ps = params
|
||||
β = (β1, β2)
|
||||
opt = ADAMW(η, β)
|
||||
opt = check_decay(opt, decay)
|
||||
decay != 0 && (opt = Optimiser(opt, WeightDecay(decay)))
|
||||
updaterule(opt, ps)
|
||||
end
|
||||
|
||||
# Old training loop
|
||||
|
||||
struct OldOptimiser
|
||||
func
|
||||
end
|
||||
|
||||
_update_params!(opt::OldOptimiser, ps) = opt.func()
|
||||
|
||||
# Train function
|
||||
function train!(loss, data, opt; cb = () -> ())
|
||||
depwarn("train!(loss, data, opt) is deprecated; use train!(loss, params, data, opt) instead", :train!)
|
||||
train!(loss, (), data, OldOptimiser(opt); cb = cb)
|
||||
end
|
|
@ -37,7 +37,7 @@ Momentum(η = 0.01, ρ = 0.9) = Momentum(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::Momentum, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
@. v = ρ * v - η * Δ
|
||||
@. Δ = -v
|
||||
end
|
||||
|
@ -57,7 +57,7 @@ Nesterov(η = 0.001, ρ = 0.9) = Nesterov(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::Nesterov, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
v = get!(o.velocity, x, zero(x))::typeof(data(x))
|
||||
v = get!(o.velocity, x, zero(x))::typeof(x)
|
||||
d = @. ρ^2 * v - (1+ρ) * η * Δ
|
||||
@. v = ρ*v - η*Δ
|
||||
@. Δ = -d
|
||||
|
@ -80,7 +80,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
|
|||
|
||||
function apply!(o::RMSProp, x, Δ)
|
||||
η, ρ = o.eta, o.rho
|
||||
acc = get!(o.acc, x, zero(x))::typeof(data(x))
|
||||
acc = get!(o.acc, x, zero(x))::typeof(x)
|
||||
@. acc = ρ * acc + (1 - ρ) * Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
@ -147,7 +147,7 @@ ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
|
|||
|
||||
function apply!(o::ADAGrad, x, Δ)
|
||||
η = o.eta
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(data(x))
|
||||
acc = get!(o.acc, x, fill(ϵ, size(x)))::typeof(x)
|
||||
@. acc += Δ^2
|
||||
@. Δ *= η / (√acc + ϵ)
|
||||
end
|
||||
|
@ -323,5 +323,5 @@ WeightDecay() = WeightDecay(0)
|
|||
|
||||
function apply!(o::WeightDecay, x, Δ)
|
||||
wd = o.wd
|
||||
@. Δ += wd * data(x)
|
||||
@. Δ += wd * x
|
||||
end
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
using Juno
|
||||
import Flux.Tracker: Params, gradient, data, update!
|
||||
import Base.depwarn
|
||||
import Zygote: Params, gradient
|
||||
|
||||
function update!(x::AbstractArray, x̄)
|
||||
x .+= x̄
|
||||
return x
|
||||
end
|
||||
|
||||
function update!(opt, x, x̄)
|
||||
update!(x, -apply!(opt, x, data(x̄)))
|
||||
update!(x, -apply!(opt, x, x̄))
|
||||
end
|
||||
|
||||
function update!(opt, xs::Params, gs)
|
||||
|
@ -12,15 +16,6 @@ function update!(opt, xs::Params, gs)
|
|||
end
|
||||
end
|
||||
|
||||
# Added as an internal API but everyone started using it.
|
||||
function _update_params!(opt, xs)
|
||||
depwarn("`_update_params!` is deprecated, use `update!` instead.", :stop)
|
||||
for x in xs
|
||||
update!(opt, x, Tracker.grad(x))
|
||||
x.tracker.grad = Tracker.zero_grad!(x.tracker.grad)
|
||||
end
|
||||
end
|
||||
|
||||
# Callback niceties
|
||||
call(f, xs...) = f(xs...)
|
||||
runall(f) = f
|
||||
|
@ -72,10 +67,6 @@ function train!(loss, ps, data, opt; cb = () -> ())
|
|||
loss(d...)
|
||||
end
|
||||
update!(opt, ps, gs)
|
||||
if cb() == :stop
|
||||
depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
|
||||
break
|
||||
end
|
||||
catch ex
|
||||
if ex isa StopException
|
||||
break
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import Adapt: adapt, adapt_storage
|
||||
import .Tracker: IdSet
|
||||
import Zygote: IdSet
|
||||
|
||||
children(x) = ()
|
||||
mapchildren(f, x) = x
|
||||
|
@ -39,7 +39,7 @@ end
|
|||
function params(m)
|
||||
ps = Params()
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
p isa AbstractArray{<:Real} &&
|
||||
!any(p′ -> p′ === p, ps) && push!(ps, p),
|
||||
m)
|
||||
return ps
|
||||
|
@ -51,7 +51,7 @@ function loadparams!(m, xs)
|
|||
for (p, x) in zip(params(m), xs)
|
||||
size(p) == size(x) ||
|
||||
error("Expected param size $(size(p)), got $(size(x))")
|
||||
copyto!(data(p), data(x))
|
||||
copyto!(p, x)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -80,8 +80,6 @@ f64(m) = paramtype(Float64, m)
|
|||
|
||||
function mapparams(f, m)
|
||||
mapleaves(m) do x
|
||||
Tracker.istracked(x) ? param(f(Tracker.data(x))) :
|
||||
x isa Union{AbstractArray,Number} ? f(x) :
|
||||
x
|
||||
x isa Union{AbstractArray,Number} ? f(x) : x
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux, CuArrays, Test
|
||||
using Flux: gpu
|
||||
|
||||
@info "Testing GPU Support"
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux.Tracker: TrackedArray, data
|
||||
using Flux, CuArrays, Test
|
||||
|
||||
@testset "CUDNN BatchNorm" begin
|
||||
@testset "4D Input" begin
|
||||
|
|
|
@ -1,202 +1,201 @@
|
|||
using Flux: testmode!
|
||||
using Flux.Tracker: data
|
||||
using Flux, Test
|
||||
using Zygote: forward
|
||||
|
||||
trainmode(f, x...) = forward(f, x...)[1]
|
||||
|
||||
@testset "Dropout" begin
|
||||
x = [1.,2.,3.]
|
||||
@test x == testmode!(Dropout(0.1))(x)
|
||||
@test x == Dropout(0)(x)
|
||||
@test zero(x) == Dropout(1)(x)
|
||||
@test x == Dropout(0.1)(x)
|
||||
@test x == trainmode(Dropout(0), (x))
|
||||
@test zero(x) == trainmode(Dropout(1), (x))
|
||||
|
||||
x = rand(100)
|
||||
m = Dropout(0.9)
|
||||
y = m(x)
|
||||
y = trainmode(m, x)
|
||||
@test count(a->a==0, y) > 50
|
||||
testmode!(m)
|
||||
y = m(x)
|
||||
@test count(a->a==0, y) == 0
|
||||
testmode!(m, false)
|
||||
y = m(x)
|
||||
y = trainmode(m, x)
|
||||
@test count(a->a==0, y) > 50
|
||||
|
||||
x = rand(100)
|
||||
x = rand(Float32, 100)
|
||||
m = Chain(Dense(100,100),
|
||||
Dropout(0.9))
|
||||
y = m(x)
|
||||
y = trainmode(m, x)
|
||||
@test count(a->a == 0, y) > 50
|
||||
testmode!(m)
|
||||
y = m(x)
|
||||
@test count(a->a == 0, y) == 0
|
||||
end
|
||||
|
||||
@testset "BatchNorm" begin
|
||||
let m = BatchNorm(2), x = param([1 3 5;
|
||||
2 4 6])
|
||||
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
# initial m.σ is 1
|
||||
# initial m.μ is 0
|
||||
@test m.active
|
||||
|
||||
# @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
|
||||
m(x)
|
||||
|
||||
# julia> x
|
||||
# 2×3 Array{Float64,2}:
|
||||
# 1.0 3.0 5.0
|
||||
# 2.0 4.0 6.0
|
||||
#
|
||||
# μ of batch will be
|
||||
# (1. + 3. + 5.) / 3 = 3
|
||||
# (2. + 4. + 6.) / 3 = 4
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# .1 * 3 + 0 = .3
|
||||
# .1 * 4 + 0 = .4
|
||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||
|
||||
# julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# 2×1 Array{Float64,2}:
|
||||
# 1.3
|
||||
# 1.3
|
||||
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
|
||||
end
|
||||
|
||||
# with activation function
|
||||
let m = BatchNorm(2, sigmoid), x = param([1 3 5;
|
||||
2 4 6])
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
y = m(x).data
|
||||
@test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
|
||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
# @testset "BatchNorm" begin
|
||||
# let m = BatchNorm(2), x = [1 3 5;
|
||||
# 2 4 6]
|
||||
#
|
||||
# @test m.β.data == [0, 0] # initβ(2)
|
||||
# @test m.γ.data == [1, 1] # initγ(2)
|
||||
# # initial m.σ is 1
|
||||
# # initial m.μ is 0
|
||||
# @test m.active
|
||||
#
|
||||
# # @test m(x).data ≈ [-1 -1; 0 0; 1 1]'
|
||||
# m(x)
|
||||
#
|
||||
# # julia> x
|
||||
# # 2×3 Array{Float64,2}:
|
||||
# # 1.0 3.0 5.0
|
||||
# # 2.0 4.0 6.0
|
||||
# #
|
||||
# # μ of batch will be
|
||||
# # (1. + 3. + 5.) / 3 = 3
|
||||
# # (2. + 4. + 6.) / 3 = 4
|
||||
# #
|
||||
# # ∴ update rule with momentum:
|
||||
# # .1 * 3 + 0 = .3
|
||||
# # .1 * 4 + 0 = .4
|
||||
# @test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||
#
|
||||
# # julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# # 2×1 Array{Float64,2}:
|
||||
# # 1.3
|
||||
# # 1.3
|
||||
# @test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
#
|
||||
# testmode!(m)
|
||||
# @test !m.active
|
||||
#
|
||||
# x′ = m(x).data
|
||||
# @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5)
|
||||
# end
|
||||
#
|
||||
# # with activation function
|
||||
# let m = BatchNorm(2, sigmoid), x = param([1 3 5;
|
||||
# 2 4 6])
|
||||
# @test m.active
|
||||
# m(x)
|
||||
#
|
||||
# testmode!(m)
|
||||
# @test !m.active
|
||||
#
|
||||
# y = m(x).data
|
||||
# @test isapprox(y, data(sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ))), atol = 1.0e-7)
|
||||
# end
|
||||
#
|
||||
# let m = BatchNorm(2), x = param(reshape(1:6, 3, 2, 1))
|
||||
# y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||
# y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||
# @test m(x) == y
|
||||
# end
|
||||
#
|
||||
# let m = BatchNorm(2), x = param(reshape(1:12, 2, 3, 2, 1))
|
||||
# y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||
# y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||
# @test m(x) == y
|
||||
# end
|
||||
#
|
||||
# let m = BatchNorm(2), x = param(reshape(1:24, 2, 2, 3, 2, 1))
|
||||
# y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||
# y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
# @test m(x) == y
|
||||
# end
|
||||
#
|
||||
# let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
# m(x)
|
||||
# @test (@allocated m(x)) < 100_000_000
|
||||
# end
|
||||
# end
|
||||
|
||||
|
||||
@testset "InstanceNorm" begin
|
||||
# helper functions
|
||||
expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
# begin tests
|
||||
let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
@test m.β.data == [0, 0] # initβ(2)
|
||||
@test m.γ.data == [1, 1] # initγ(2)
|
||||
|
||||
@test m.active
|
||||
|
||||
m(x)
|
||||
|
||||
#julia> x
|
||||
#[:, :, 1] =
|
||||
# 1.0 4.0
|
||||
# 2.0 5.0
|
||||
# 3.0 6.0
|
||||
#
|
||||
#[:, :, 2] =
|
||||
# 7.0 10.0
|
||||
# 8.0 11.0
|
||||
# 9.0 12.0
|
||||
#
|
||||
# μ will be
|
||||
# (1. + 2. + 3.) / 3 = 2.
|
||||
# (4. + 5. + 6.) / 3 = 5.
|
||||
#
|
||||
# (7. + 8. + 9.) / 3 = 8.
|
||||
# (10. + 11. + 12.) / 3 = 11.
|
||||
#
|
||||
# ∴ update rule with momentum:
|
||||
# (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||
# (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||
@test m.μ ≈ [0.5, 0.8]
|
||||
# momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||
# julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
# 2-element Array{Float64,1}:
|
||||
# 1.
|
||||
# 1.
|
||||
@test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
x′ = m(x).data
|
||||
@test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||
end
|
||||
# with activation function
|
||||
let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
|
||||
affine_shape = collect(sizes)
|
||||
affine_shape[1] = 1
|
||||
|
||||
@test m.active
|
||||
m(x)
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
||||
y = m(x).data
|
||||
@test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
y = reshape(m(y), sizes...)
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
# check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
y = m(x)
|
||||
@test size(m.μ) == (sizes[end - 1], )
|
||||
@test size(m.σ²) == (sizes[end - 1], )
|
||||
@test size(y) == sizes
|
||||
end
|
||||
|
||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||
x = param(reshape(collect(1:prod(sizes)), sizes))
|
||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
end
|
||||
|
||||
let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
|
||||
end
|
||||
# @testset "InstanceNorm" begin
|
||||
# # helper functions
|
||||
# expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...)
|
||||
# # begin tests
|
||||
# let m = InstanceNorm(2), sizes = (3, 2, 2),
|
||||
# x = reshape(collect(1:prod(sizes)), sizes)
|
||||
#
|
||||
# @test m.β.data == [0, 0] # initβ(2)
|
||||
# @test m.γ.data == [1, 1] # initγ(2)
|
||||
#
|
||||
# @test m.active
|
||||
#
|
||||
# m(x)
|
||||
#
|
||||
# #julia> x
|
||||
# #[:, :, 1] =
|
||||
# # 1.0 4.0
|
||||
# # 2.0 5.0
|
||||
# # 3.0 6.0
|
||||
# #
|
||||
# #[:, :, 2] =
|
||||
# # 7.0 10.0
|
||||
# # 8.0 11.0
|
||||
# # 9.0 12.0
|
||||
# #
|
||||
# # μ will be
|
||||
# # (1. + 2. + 3.) / 3 = 2.
|
||||
# # (4. + 5. + 6.) / 3 = 5.
|
||||
# #
|
||||
# # (7. + 8. + 9.) / 3 = 8.
|
||||
# # (10. + 11. + 12.) / 3 = 11.
|
||||
# #
|
||||
# # ∴ update rule with momentum:
|
||||
# # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5
|
||||
# # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8
|
||||
# @test m.μ ≈ [0.5, 0.8]
|
||||
# # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq
|
||||
# # julia> reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
# # 2-element Array{Float64,1}:
|
||||
# # 1.
|
||||
# # 1.
|
||||
# @test m.σ² ≈ reshape(mean(.1 .* var(x.data, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1.
|
||||
#
|
||||
# testmode!(m)
|
||||
# @test !m.active
|
||||
#
|
||||
# x′ = m(x).data
|
||||
# @test isapprox(x′[1], (1 - 0.5) / sqrt(1. + 1f-5), atol = 1.0e-5)
|
||||
# end
|
||||
# # with activation function
|
||||
# let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2),
|
||||
# x = reshape(collect(1:prod(sizes)), sizes)
|
||||
#
|
||||
# affine_shape = collect(sizes)
|
||||
# affine_shape[1] = 1
|
||||
#
|
||||
# @test m.active
|
||||
# m(x)
|
||||
#
|
||||
# testmode!(m)
|
||||
# @test !m.active
|
||||
#
|
||||
# y = m(x).data
|
||||
# @test isapprox(y, data(sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ))), atol = 1.0e-7)
|
||||
# end
|
||||
#
|
||||
# let m = InstanceNorm(2), sizes = (2, 4, 1, 2, 3),
|
||||
# x = reshape(collect(1:prod(sizes)), sizes)
|
||||
# y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||
# y = reshape(m(y), sizes...)
|
||||
# @test m(x) == y
|
||||
# end
|
||||
#
|
||||
# # check that μ, σ², and the output are the correct size for higher rank tensors
|
||||
# let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6),
|
||||
# x = reshape(collect(1:prod(sizes)), sizes)
|
||||
# y = m(x)
|
||||
# @test size(m.μ) == (sizes[end - 1], )
|
||||
# @test size(m.σ²) == (sizes[end - 1], )
|
||||
# @test size(y) == sizes
|
||||
# end
|
||||
#
|
||||
# # show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||
# let m_inorm = InstanceNorm(2), m_bnorm = BatchNorm(12), sizes = (5, 5, 3, 4, 2, 6),
|
||||
# x = reshape(collect(1:prod(sizes)), sizes)
|
||||
# @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||
# end
|
||||
#
|
||||
# let m = InstanceNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
# m(x)
|
||||
# @test (@allocated m(x)) < 100_000_000
|
||||
# end
|
||||
#
|
||||
# end
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Test
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
σ, binarycrossentropy, logitbinarycrossentropy
|
||||
using Zygote
|
||||
|
||||
const ϵ = 1e-7
|
||||
|
||||
|
@ -55,9 +56,9 @@ const ϵ = 1e-7
|
|||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy)
|
||||
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
|
||||
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
|
||||
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
|
||||
fwd, back = Zygote.forward(mse, ŷ, y)
|
||||
@test fwd isa T
|
||||
@test eltype(back(one(T))[1]) == T
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
100
test/optimise.jl
100
test/optimise.jl
|
@ -1,55 +1,55 @@
|
|||
using Flux.Optimise
|
||||
using Flux.Optimise: runall
|
||||
using Flux.Tracker
|
||||
using Zygote: Params, gradient
|
||||
using Test
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||
NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
|
||||
Momentum()]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
for t = 1: 10^5
|
||||
θ = Params([w′])
|
||||
θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
Optimise.update!(opt, θ, θ̄)
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
end
|
||||
# @testset "Optimise" begin
|
||||
# w = randn(10, 10)
|
||||
# @testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
|
||||
# NADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
|
||||
# Momentum()]
|
||||
# w′ = randn(10, 10)
|
||||
# loss(x) = Flux.mse(w*x, w′*x)
|
||||
# for t = 1: 10^5
|
||||
# θ = Params([w′])
|
||||
# θ̄ = gradient(() -> loss(rand(10)), θ)
|
||||
# Optimise.update!(opt, θ, θ̄)
|
||||
# end
|
||||
# @test Flux.mse(w, w′) < 0.01
|
||||
# end
|
||||
# end
|
||||
|
||||
@testset "Optimiser" begin
|
||||
w = randn(10, 10)
|
||||
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Optimiser(Opt(), ADAM(0.001))
|
||||
for t = 1:10^5
|
||||
l = loss(rand(10))
|
||||
back!(l)
|
||||
delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
w′.data .-= delta
|
||||
end
|
||||
@test Flux.mse(w, w′) < 0.01
|
||||
end
|
||||
end
|
||||
# @testset "Optimiser" begin
|
||||
# w = randn(10, 10)
|
||||
# @testset for Opt in [InvDecay, WeightDecay, ExpDecay]
|
||||
# w′ = param(randn(10, 10))
|
||||
# loss(x) = Flux.mse(w*x, w′*x)
|
||||
# opt = Optimiser(Opt(), ADAM(0.001))
|
||||
# for t = 1:10^5
|
||||
# l = loss(rand(10))
|
||||
# back!(l)
|
||||
# delta = Optimise.apply!(opt, w′.data, w′.grad)
|
||||
# w′.data .-= delta
|
||||
# end
|
||||
# @test Flux.mse(w, w′) < 0.01
|
||||
# end
|
||||
# end
|
||||
|
||||
@testset "Training Loop" begin
|
||||
i = 0
|
||||
l = param(1)
|
||||
|
||||
Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||
(),
|
||||
Iterators.repeated((), 100),
|
||||
Descent(),
|
||||
cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
|
||||
|
||||
@test 3 < i < 50
|
||||
|
||||
# Test multiple callbacks
|
||||
x = 0
|
||||
fs = [() -> (), () -> x = 1]
|
||||
cbs = runall(fs)
|
||||
cbs()
|
||||
@test x == 1
|
||||
end
|
||||
# @testset "Training Loop" begin
|
||||
# i = 0
|
||||
# l = 1
|
||||
#
|
||||
# Flux.train!(() -> (sleep(0.1); i += 1; l),
|
||||
# (),
|
||||
# Iterators.repeated((), 100),
|
||||
# Descent(),
|
||||
# cb = Flux.throttle(() -> (i > 3 && Flux.stop()), 1))
|
||||
#
|
||||
# @test 3 < i < 50
|
||||
#
|
||||
# # Test multiple callbacks
|
||||
# x = 0
|
||||
# fs = [() -> (), () -> x = 1]
|
||||
# cbs = runall(fs)
|
||||
# cbs()
|
||||
# @test x == 1
|
||||
# end
|
||||
|
|
|
@ -1,5 +1,23 @@
|
|||
using Flux, Test
|
||||
using Tracker: gradcheck
|
||||
|
||||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
x[i] = tmp - δ/2
|
||||
y1 = f(xs...)
|
||||
x[i] = tmp + δ/2
|
||||
y2 = f(xs...)
|
||||
x[i] = tmp
|
||||
Δ[i] = (y2-y1)/δ
|
||||
end
|
||||
return grads
|
||||
end
|
||||
|
||||
gradcheck(f, xs...) =
|
||||
all(isapprox.(ngradient(f, xs...),
|
||||
gradient(f, xs...), rtol = 1e-5, atol = 1e-5))
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
||||
|
@ -9,7 +27,7 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||
@test gradtest(Flux.mse, rand(5,5), rand(5, 5))
|
||||
@test gradtest(Flux.crossentropy, rand(5,5), rand(5, 5))
|
||||
|
||||
@test gradtest(x -> Flux.normalise(x), rand(4,3))
|
||||
@test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
|
||||
# @test gradtest(x -> Flux.normalise(x), rand(4,3))
|
||||
# @test gradtest(x -> Flux.normalise(x, dims = 2), rand(3,4))
|
||||
|
||||
end
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux
|
||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack, unstack
|
||||
using Flux: throttle, glorot_uniform, glorot_normal, stack, unstack
|
||||
using StatsBase: std
|
||||
using Random
|
||||
using Test
|
||||
|
@ -52,15 +52,6 @@ using Test
|
|||
end
|
||||
end
|
||||
|
||||
@testset "Jacobian" begin
|
||||
A = param(randn(2,2))
|
||||
x = randn(2)
|
||||
m(x) = A*x
|
||||
y = m(x)
|
||||
J = jacobian(m,x)
|
||||
@test J ≈ A.data
|
||||
end
|
||||
|
||||
@testset "Initialization" begin
|
||||
# Set random seed so that these tests don't fail randomly
|
||||
Random.seed!(0)
|
||||
|
@ -96,12 +87,11 @@ end
|
|||
@testset "Precision" begin
|
||||
m = Chain(Dense(10, 5, relu), Dense(5, 2))
|
||||
x = rand(10)
|
||||
@test eltype(m[1].W.data) == Float32
|
||||
@test eltype(m(x).data) == Float32
|
||||
@test eltype(f64(m)(x).data) == Float64
|
||||
@test eltype(f64(m)[1].W.data) == Float64
|
||||
@test eltype(f32(f64(m))[1].W.data) == Float32
|
||||
@test Tracker.isleaf(f32(f64(m))[1].W)
|
||||
@test eltype(m[1].W) == Float32
|
||||
@test eltype(m(x)) == Float32
|
||||
@test eltype(f64(m)(x)) == Float64
|
||||
@test eltype(f64(m)[1].W) == Float64
|
||||
@test eltype(f32(f64(m))[1].W) == Float32
|
||||
end
|
||||
|
||||
@testset "Stacking" begin
|
||||
|
|
Loading…
Reference in New Issue