This commit is contained in:
Simon Mandlik 2018-07-18 15:39:20 +02:00 committed by Mike J Innes
parent 3510c837a8
commit 0471c489e6
10 changed files with 20 additions and 16 deletions

View File

@ -25,13 +25,13 @@ end
function phones()
load()
Symbol.(first.(split.(split(read(deps("cmudict", "cmudict.phones"),String),
"\n", keep = false), "\t")))
"\n", keepempty = false), "\t")))
end
function symbols()
load()
Symbol.(split(read(deps("cmudict", "cmudict.symbols"),String),
"\n", keep = false))
"\n", keepempty = false))
end
function rawdict()
@ -42,7 +42,7 @@ end
validword(s) = isascii(s) && occursin(r"^[\w\-\.]+$", s)
cmudict() = filter((s, ps) -> validword(s), rawdict())
cmudict() = filter(p -> validword(p.first), rawdict())
alphabet() = ['A':'Z'..., '0':'9'..., '_', '-', '.']

View File

@ -39,13 +39,13 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
end
function onehot(l, labels)
i = findfirst(labels, l)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || error("Value $l is not in labels")
OneHotVector(i, length(labels))
end
function onehot(l, labels, unk)
i = findfirst(labels, l)
i = something(findfirst(isequal(l), labels), 0)
i > 0 || return onehot(unk, labels)
OneHotVector(i, length(labels))
end

View File

@ -2,7 +2,7 @@ import Base: *, ==
import LinearAlgebra
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm
using LinearAlgebra: Transpose, Adjoint, diagm, diag
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
tracker::Tracked{A}
@ -94,7 +94,7 @@ Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
S = size(xs)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
@ -256,7 +256,7 @@ LinearAlgebra.vecnorm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
@grad mean(xs, region) = mean(data(xs), dims = region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
@grad function maximum(xs, r...)
maximum(data(xs), r...), function (Δ)

View File

@ -96,7 +96,7 @@ end
@forward Grads.grads Base.setindex!, Base.haskey
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)

View File

@ -1,4 +1,4 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu
info("Testing Flux/GPU")

View File

@ -1,4 +1,4 @@
using Flux, CuArrays, Base.Test
using Flux, CuArrays, Test
info("Testing Flux/CUDNN")

View File

@ -1,5 +1,5 @@
using Flux.Data
using Base.Test
using Test
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args

View File

@ -1,4 +1,4 @@
using Base.Test
using Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy

View File

@ -1,7 +1,10 @@
using Flux
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean
# using StatsBase
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -110,7 +113,7 @@ end
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
end
@ -163,7 +166,7 @@ end
@test gradtest((x, y) -> x .* y, rand(5), rand(5))
@test gradtest(dot, rand(5), rand(5))
@test gradtest(vecnorm, rand(5))
@test gradtest(norm, rand(5))
@test gradtest(rand(5)) do x
y = x.^2

View File

@ -1,5 +1,6 @@
using Flux: throttle, initn, glorot_uniform, glorot_normal, jacobian
using StatsBase: std
using Dates
@testset "Throttle" begin
@testset "default behaviour" begin