commit
1750cda59e
|
@ -4,7 +4,7 @@ module Flux
|
|||
|
||||
# Zero Flux Given
|
||||
|
||||
using MacroTools, Juno, Requires, Reexport, StatsBase
|
||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||
using MacroTools: @forward
|
||||
|
||||
export Chain, Dense, RNN, LSTM, GRU, Conv,
|
||||
|
|
|
@ -198,7 +198,7 @@ end
|
|||
|
||||
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||
# Same as above, any more efficient way?
|
||||
dy = dy_ isa Integer ? zeros(y) : dy_
|
||||
dy = dy_ isa Integer ? zero(y) : dy_
|
||||
yd = xDesc(y)
|
||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||
dh = similar(h)
|
||||
|
@ -229,7 +229,7 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
|
|||
end
|
||||
|
||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||
dw = zeros(rnn.params)
|
||||
dw = zero(rnn.params)
|
||||
cudnnRNNBackwardWeights(rnn, 1,
|
||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
|
|
|
@ -24,25 +24,25 @@ end
|
|||
|
||||
function phones()
|
||||
load()
|
||||
Symbol.(first.(split.(split(readstring(deps("cmudict", "cmudict.phones")),
|
||||
"\n", keep = false), "\t")))
|
||||
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
|
||||
"\n", keepempty = false), "\t")))
|
||||
end
|
||||
|
||||
function symbols()
|
||||
load()
|
||||
Symbol.(split(readstring(deps("cmudict", "cmudict.symbols")),
|
||||
"\n", keep = false))
|
||||
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
|
||||
"\n", keepempty = false))
|
||||
end
|
||||
|
||||
function rawdict()
|
||||
load()
|
||||
Dict(String(xs[1]) => Symbol.(xs[2:end]) for xs in
|
||||
filter(!isempty, split.(split(readstring(deps("cmudict", "cmudict")), "\n"))))
|
||||
filter(!isempty, split.(split(read(deps("cmudict", "cmudict"),String), "\n"))))
|
||||
end
|
||||
|
||||
validword(s) = isascii(s) && ismatch(r"^[\w\-\.]+$", s)
|
||||
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
|
||||
|
||||
cmudict() = filter((s, ps) -> validword(s), rawdict())
|
||||
cmudict() = filter(p -> validword(p.first), rawdict())
|
||||
|
||||
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ children(c::Chain) = c.layers
|
|||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||
adapt(T, c::Chain) = Chain(map(x -> adapt(T, x), c.layers)...)
|
||||
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
|
||||
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
|
||||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ Conv(w::AbstractArray{T,N}, b::AbstractVector{T}, σ = identity;
|
|||
|
||||
Conv(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity; init = initn,
|
||||
stride = 1, pad = 0, dilation = 1) where N =
|
||||
Conv(param(init(k..., ch...)), param(zeros(ch[2])), σ,
|
||||
Conv(param(init(k..., ch...)), param(zero(ch[2])), σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
|
||||
@treelike Conv
|
||||
|
|
|
@ -108,7 +108,7 @@ mutable struct BatchNorm{F,V,W,N}
|
|||
end
|
||||
|
||||
BatchNorm(chs::Integer, λ = identity;
|
||||
initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
|
||||
initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-8, momentum = .1) =
|
||||
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ end
|
|||
|
||||
function LSTMCell(in::Integer, out::Integer;
|
||||
init = glorot_uniform)
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zeros(out*4)),
|
||||
cell = LSTMCell(param(init(out*4, in)), param(init(out*4, out)), param(zero(out*4)),
|
||||
param(initn(out)), param(initn(out)))
|
||||
cell.b.data[gate(out, 2)] = 1
|
||||
return cell
|
||||
|
@ -170,7 +170,7 @@ end
|
|||
|
||||
GRUCell(in, out; init = glorot_uniform) =
|
||||
GRUCell(param(init(out*3, in)), param(init(out*3, out)),
|
||||
param(zeros(out*3)), param(initn(out)))
|
||||
param(zero(out*3)), param(initn(out)))
|
||||
|
||||
function (m::GRUCell)(h, x)
|
||||
b, o = m.b, size(h, 1)
|
||||
|
|
|
@ -39,13 +39,13 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
|
|||
end
|
||||
|
||||
function onehot(l, labels)
|
||||
i = findfirst(labels, l)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || error("Value $l is not in labels")
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
||||
function onehot(l, labels, unk)
|
||||
i = findfirst(labels, l)
|
||||
i = something(findfirst(isequal(l), labels), 0)
|
||||
i > 0 || return onehot(unk, labels)
|
||||
OneHotVector(i, length(labels))
|
||||
end
|
||||
|
|
|
@ -9,7 +9,7 @@ struct Param{T}
|
|||
Δ::T
|
||||
end
|
||||
|
||||
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
|
||||
Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zero(x))
|
||||
|
||||
include("optimisers.jl")
|
||||
include("interface.jl")
|
||||
|
|
|
@ -14,7 +14,7 @@ function descentweightdecay(p::Param, η::Real, γ::Real)
|
|||
end
|
||||
|
||||
function momentum(p::Param, ρ, η)
|
||||
v = zeros(p.x)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
@. v = ρ * v - η * p.Δ
|
||||
@. p.Δ = -v
|
||||
|
@ -23,7 +23,7 @@ end
|
|||
|
||||
# Ref. https://arxiv.org/pdf/1212.0901.pdf
|
||||
function nesterov(p::Param, ρ, η)
|
||||
v = zeros(p.x)
|
||||
v = zero(p.x)
|
||||
function ()
|
||||
d = @. ρ^2 * v - (1+ρ) * η * p.Δ
|
||||
@. v = ρ*v - η*p.Δ
|
||||
|
@ -32,7 +32,7 @@ function nesterov(p::Param, ρ, η)
|
|||
end
|
||||
|
||||
function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x)
|
||||
acc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
|
@ -40,7 +40,7 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x) .+ ϵ
|
||||
acc = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. acc += p.Δ^2
|
||||
@. p.Δ *= η / √(acc + ϵ)
|
||||
|
@ -48,8 +48,8 @@ function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
||||
acc = zeros(p.x)
|
||||
Δacc = zeros(p.x)
|
||||
acc = zero(p.x)
|
||||
Δacc = zero(p.x)
|
||||
function ()
|
||||
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
|
||||
@. p.Δ *= √(Δacc + ϵ) / √(acc + ϵ)
|
||||
|
@ -58,8 +58,8 @@ function adadelta(p::Param; ρ::Real = 0.9, ϵ::Real = 1e-8)
|
|||
end
|
||||
|
||||
function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
@ -71,8 +71,8 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
|
|||
end
|
||||
|
||||
function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
ut = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
ut = zero(p.x)
|
||||
β1p = β1
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
@ -83,9 +83,9 @@ function adamax(p::Param; η::Real = 0.002, β1::Real = 0.9, β2::Real = 0.999,
|
|||
end
|
||||
|
||||
function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x) .+ ϵ
|
||||
v̂t = zeros(p.x) .+ ϵ
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x) .+ ϵ
|
||||
v̂t = zero(p.x) .+ ϵ
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
@. vt = β2 * vt + (1 - β2) * p.Δ ^ 2
|
||||
|
@ -95,8 +95,8 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
|
|||
end
|
||||
|
||||
function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
|
||||
mt = zeros(p.x)
|
||||
vt = zeros(p.x)
|
||||
mt = zero(p.x)
|
||||
vt = zero(p.x)
|
||||
β1p, β2p = β1, β2
|
||||
function ()
|
||||
@. mt = β1 * mt + (1 - β1) * p.Δ
|
||||
|
|
|
@ -46,14 +46,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
|||
|
||||
function _forward end
|
||||
|
||||
function track(f::F, xs...) where F
|
||||
y, back = _forward(f, xs...)
|
||||
ts = map(tracker, xs)
|
||||
c = Call(back, ts)
|
||||
track(c, y)
|
||||
end
|
||||
|
||||
function track_kw(f::F, xs...; kw...) where F
|
||||
function track(f::F, xs...; kw...) where F
|
||||
y, back = _forward(f, xs...; kw...)
|
||||
track(Call(back, tracker.(xs)), y)
|
||||
end
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import Base: *, ==
|
||||
|
||||
import LinearAlgebra
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
tracker::Tracked{A}
|
||||
|
@ -26,7 +27,7 @@ TrackedArray(c::Call, x::A) where A <: AbstractArray =
|
|||
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
|
||||
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x))
|
||||
|
||||
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
|
||||
|
||||
|
@ -85,7 +86,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
|
||||
|
||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
|
||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||
|
@ -93,7 +94,7 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
|||
S = size(xs)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
|
||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
|
@ -159,10 +160,10 @@ end
|
|||
end
|
||||
end
|
||||
|
||||
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
|
@ -204,32 +205,28 @@ Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
Base.sum(xs::TrackedArray) = track(sum, xs)
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
|
||||
|
||||
@grad sum(xs, dim...) = sum(data(xs), dim...),
|
||||
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
|
||||
@grad sum(xs; dims = :) = sum(data(xs), dims = dims),
|
||||
Δ -> (zero(xs) .+ Δ, )
|
||||
|
||||
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
|
||||
Base.prod(xs::TrackedArray) = track(prod, xs)
|
||||
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
||||
|
||||
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
|
||||
@grad prod(xs, dim) = prod(data(xs), dim),
|
||||
@grad prod(xs, dim) = prod(data(xs), dims = dim),
|
||||
Δ -> (nobacksies(:sum,
|
||||
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
|
||||
nothing)
|
||||
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
|
||||
Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
|
||||
|
||||
import LinearAlgebra: dot
|
||||
|
||||
|
@ -239,34 +236,33 @@ dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
|||
|
||||
@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
|
||||
|
||||
using StatsBase
|
||||
|
||||
# Hacks to get std working
|
||||
StatsBase.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
|
||||
StatsBase.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))
|
||||
Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)) = _std(x,mean,dims)
|
||||
_std(x::TrackedArray, mean, dims) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - 1))
|
||||
_std(x::TrackedArray, mean, ::Colon) = sqrt.(sum((x .- mean).^2) ./ (length(x) - 1))
|
||||
|
||||
LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
|
||||
LinearAlgebra.norm(x::TrackedArray, p::Real = 2) =
|
||||
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
|
||||
|
||||
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
|
||||
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
@grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
|
||||
_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
|
||||
_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
|
||||
|
||||
@grad function maximum(xs, r...)
|
||||
maximum(data(xs), r...), function (Δ)
|
||||
@grad function maximum(xs; dims = dims)
|
||||
maximum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmax(data(xs), r...)
|
||||
_, i = findmax(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:maximum, Δ′),)
|
||||
end
|
||||
end
|
||||
@grad function minimum(xs, r...)
|
||||
minimum(data(xs), r...), function (Δ)
|
||||
|
||||
@grad function minimum(xs; dims = dims)
|
||||
minimum(data(xs), dims = dims), function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
_, i = findmin(data(xs), r...)
|
||||
_, i = findmin(data(xs), dims = dims)
|
||||
Δ′[i] = data(Δ)
|
||||
return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
|
||||
return (nobacksies(:minimum, Δ′),)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -312,9 +308,9 @@ logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
|
|||
|
||||
@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
|
||||
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
|
||||
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
|
||||
|
||||
@grad conv(x, w; kw...) =
|
||||
conv(data(x), data(w); kw...),
|
||||
|
@ -322,14 +318,14 @@ conv(x::TrackedArray, w::AbstractArray; kw...) = track_kw(conv, x, w; kw...)
|
|||
(NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
|
||||
NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
|
||||
|
||||
maxpool(x::TrackedArray, k; kw...) = track_kw(maxpool, x, k; kw...)
|
||||
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
|
||||
|
||||
@grad function maxpool(x, k; kw...)
|
||||
y = maxpool(data(x), k; kw...)
|
||||
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
|
||||
end
|
||||
|
||||
meanpool(x::TrackedArray, k; kw...) = track_kw(meanpool, x, k; kw...)
|
||||
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
|
||||
|
||||
@grad function meanpool(x, k; kw...)
|
||||
y = meanpool(data(x), k; kw...)
|
||||
|
@ -349,7 +345,7 @@ dualify(xs::Real, ps) = Dual(xs, ps)
|
|||
|
||||
unbroadcast(x::Tuple, Δ) =
|
||||
x == size(Δ) ? Δ :
|
||||
reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
reshape(sum(Δ, dims = filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
|
||||
|
||||
unbroadcast(x::Tuple{}, Δ) = sum(Δ)
|
||||
|
||||
|
|
|
@ -96,7 +96,7 @@ end
|
|||
|
||||
@forward Grads.grads Base.setindex!, Base.haskey
|
||||
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
|
||||
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
||||
|
||||
function back_(g::Grads, c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
function ngradient(f, xs::AbstractArray...)
|
||||
grads = zeros.(xs)
|
||||
grads = zero.(xs)
|
||||
for (x, Δ) in zip(xs, grads), i in 1:length(x)
|
||||
δ = sqrt(eps())
|
||||
tmp = x[i]
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# Arrays
|
||||
|
||||
initn(dims...) = randn(dims...)/100
|
||||
glorot_uniform(dims...) = (rand(dims...) - 0.5)*sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = (randn(dims...)*sqrt(2.0/sum(dims)))
|
||||
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
|
||||
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
|
||||
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
|
@ -145,7 +145,7 @@ function jacobian(m,x)
|
|||
y = m(xp)
|
||||
k = length(y)
|
||||
n = length(x)
|
||||
J = Matrix{eltype(x)}(n,k)
|
||||
J = Matrix{eltype(x)}(undef,n,k)
|
||||
for i = 1:k
|
||||
Flux.back!(y[i]) # Populate gradient accumulator
|
||||
J[:,i] = xp.grad
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||
using Flux, Flux.Tracker, CuArrays, Test
|
||||
using Flux: gpu
|
||||
|
||||
info("Testing Flux/GPU")
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Flux, CuArrays, Base.Test
|
||||
using Flux, CuArrays, Test
|
||||
|
||||
info("Testing Flux/CUDNN")
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux.Data
|
||||
using Base.Test
|
||||
using Test
|
||||
|
||||
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ using Flux: testmode!
|
|||
x = [1.,2.,3.]
|
||||
@test x == testmode!(Dropout(0.1))(x)
|
||||
@test x == Dropout(0)(x)
|
||||
@test zeros(x) == Dropout(1)(x)
|
||||
@test zero(x) == Dropout(1)(x)
|
||||
|
||||
x = rand(100)
|
||||
m = Dropout(0.9)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Base.Test
|
||||
using Test
|
||||
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
|
||||
σ, binarycrossentropy, logitbinarycrossentropy
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
using Flux, Base.Test
|
||||
using Flux, Test, Random
|
||||
|
||||
srand(0)
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
using StatsBase
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
using Statistics: mean, std
|
||||
# using StatsBase
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
@ -14,11 +17,14 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
|
||||
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
|
||||
|
||||
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
|
||||
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
|
||||
@test gradtest((w, x) -> w'*x, randn(Float64,10, 2), randn(Float64,10))
|
||||
@test gradtest((w, x) -> w*x', randn(Float64,5,5), randn(Float64,5,5))
|
||||
|
||||
@test gradtest(x -> sum(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x, (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = (2, 3)), (3,4,5))
|
||||
@test gradtest(x -> sum(x, dims = 1), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x, dims = [1,2]), randn(Float64,2,3))
|
||||
@test gradtest(x -> sum(x), randn(Float64,2,3))
|
||||
@test gradtest(x -> prod(x, dims=(2, 3)), (3,4,5))
|
||||
@test gradtest(x -> prod(x), (3,4,5))
|
||||
|
||||
@test gradtest(x -> softmax(x).*(1:3), 3)
|
||||
|
@ -96,7 +102,7 @@ end
|
|||
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "promotiontest" begin
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
promotiontest(fcat, rand(2), rand(2), rand(2))
|
||||
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
||||
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
||||
|
@ -107,7 +113,7 @@ end
|
|||
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -127,49 +133,49 @@ end
|
|||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> mean(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> mean(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> mean(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "maximum" begin
|
||||
@test gradtest(maximum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> maximum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> maximum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "minimum" begin
|
||||
@test gradtest(minimum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> minimum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
|
||||
@test gradtest(x -> minimum(x, dims=1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, dims=3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
|
||||
@test gradtest(x -> minimum(x, dims=[1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
@test gradtest(x -> std(x, dims = 1), rand(5,5))
|
||||
|
||||
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
|
||||
@test gradtest(dot, rand(5), rand(5))
|
||||
|
||||
@test gradtest(vecnorm, rand(5))
|
||||
@test gradtest(norm, rand(5))
|
||||
|
||||
@test gradtest(rand(5)) do x
|
||||
y = x.^2
|
||||
2y + x
|
||||
end
|
||||
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(2, 2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 3, 2), randn(Float64,2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 3, 2), randn(Float64,2, 2, 3, 2))
|
||||
@test gradtest(conv, rand(10, 10, 10, 3, 2), randn(Float64,2, 2, 2, 3, 2))
|
||||
|
||||
@test gradtest(x -> maxpool(x, (2,2)), rand(10, 10, 3, 2))
|
||||
@test gradtest(x -> maxpool(x, (2,2,2)), rand(10, 10, 10, 3, 2))
|
||||
|
@ -213,7 +219,7 @@ end
|
|||
|
||||
@test @sprintf("%.2f", sum(param([1,2,3]))) == "6.00"
|
||||
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(2,2,3,4))
|
||||
@inferred NNlib.conv(param(rand(10,10,3,2)),randn(Float64,2,2,3,4))
|
||||
|
||||
b = param(rand())
|
||||
Tracker.back!(b)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
|
||||
using StatsBase: std
|
||||
using Dates
|
||||
|
||||
@testset "Throttle" begin
|
||||
@testset "default behaviour" begin
|
||||
|
|
Loading…
Reference in New Issue