Compare commits
No commits in common. "master" and "staging" have entirely different histories.
@ -14,9 +14,9 @@ version = "0.3.3"
|
||||
|
||||
[[Adapt]]
|
||||
deps = ["LinearAlgebra"]
|
||||
git-tree-sha1 = "fd04049c7dd78cfef0b06cdc1f0f181467655712"
|
||||
git-tree-sha1 = "c88cfc7f9c1f9f8633cddf0b56e86302b70f64c5"
|
||||
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
|
||||
version = "1.1.0"
|
||||
version = "1.0.1"
|
||||
|
||||
[[ArrayLayouts]]
|
||||
deps = ["FillArrays", "LinearAlgebra"]
|
||||
@ -29,9 +29,9 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
|
||||
|
||||
[[BinaryProvider]]
|
||||
deps = ["Libdl", "Logging", "SHA"]
|
||||
git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
|
||||
git-tree-sha1 = "428e9106b1ff27593cbd979afac9b45b82372b8c"
|
||||
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
|
||||
version = "0.5.10"
|
||||
version = "0.5.9"
|
||||
|
||||
[[CEnum]]
|
||||
git-tree-sha1 = "1b77a77c3b28e0b3f413f7567c9bb8dd9bdccd14"
|
||||
@ -76,9 +76,9 @@ version = "0.10.3"
|
||||
|
||||
[[Colors]]
|
||||
deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Reexport"]
|
||||
git-tree-sha1 = "1e9bba7984e78aa8cdeea7f9f7cc984ad4e4b1c7"
|
||||
git-tree-sha1 = "2fdeb981ebcf52cd800ddb6a0aa5eac34153552d"
|
||||
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||
version = "0.12.2"
|
||||
version = "0.12.0"
|
||||
|
||||
[[CommonSubexpressions]]
|
||||
deps = ["Test"]
|
||||
@ -93,16 +93,16 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
|
||||
version = "0.3.3+0"
|
||||
|
||||
[[Cthulhu]]
|
||||
deps = ["CodeTracking", "InteractiveUtils", "REPL", "UUIDs", "Unicode"]
|
||||
git-tree-sha1 = "f3643e78353199d3097821e806348bd83f364155"
|
||||
deps = ["CodeTracking", "InteractiveUtils", "REPL", "Unicode"]
|
||||
git-tree-sha1 = "a4849ec61df9659423cc63b298ed895904ee9743"
|
||||
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
|
||||
version = "1.1.1"
|
||||
version = "1.0.2"
|
||||
|
||||
[[CuArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "Requires", "SparseArrays", "Statistics", "TimerOutputs"]
|
||||
git-tree-sha1 = "1582b74d2322df7dd94549d4ac9d095e0f20e884"
|
||||
git-tree-sha1 = "870a4ac61e99c36f42d15e496fd290c841541d90"
|
||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
version = "2.2.1"
|
||||
version = "2.2.0"
|
||||
|
||||
[[DataAPI]]
|
||||
git-tree-sha1 = "176e23402d80e7743fc26c19c681bfb11246af32"
|
||||
@ -111,9 +111,9 @@ version = "1.3.0"
|
||||
|
||||
[[DataStructures]]
|
||||
deps = ["InteractiveUtils", "OrderedCollections"]
|
||||
git-tree-sha1 = "af6d9c86e191c917c2276fbede1137e8ea20157f"
|
||||
git-tree-sha1 = "6166ecfaf2b8bbf2b68d791bc1d54501f345d314"
|
||||
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
|
||||
version = "0.17.17"
|
||||
version = "0.17.15"
|
||||
|
||||
[[Dates]]
|
||||
deps = ["Printf"]
|
||||
@ -146,9 +146,9 @@ version = "0.1.1"
|
||||
|
||||
[[FillArrays]]
|
||||
deps = ["LinearAlgebra", "Random", "SparseArrays"]
|
||||
git-tree-sha1 = "44f561e293987ffc84272cd3d2b14b0b93123d63"
|
||||
git-tree-sha1 = "6c89d5b673e59b8173c546c84127e5f623d865f6"
|
||||
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
|
||||
version = "0.8.10"
|
||||
version = "0.8.9"
|
||||
|
||||
[[FixedPointNumbers]]
|
||||
git-tree-sha1 = "3ba9ea634d4c8b289d590403b4a06f8e227a6238"
|
||||
@ -173,9 +173,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
|
||||
|
||||
[[GPUArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
|
||||
git-tree-sha1 = "d887693eb1bd5e1fd573262a978745481895ec7d"
|
||||
git-tree-sha1 = "ce4579ebffef43e07318e9544ffeb6532c95d04d"
|
||||
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
|
||||
version = "3.4.1"
|
||||
version = "3.3.0"
|
||||
|
||||
[[GPUCompiler]]
|
||||
deps = ["Cthulhu", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs"]
|
||||
@ -185,9 +185,9 @@ version = "0.2.0"
|
||||
|
||||
[[IRTools]]
|
||||
deps = ["InteractiveUtils", "MacroTools", "Test"]
|
||||
git-tree-sha1 = "90ee39f9beaaa186e4968417ea2b8ed5673c91c0"
|
||||
git-tree-sha1 = "8845400bd2d9815d37720251f1b53d27a335e1f4"
|
||||
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
|
||||
version = "0.3.3"
|
||||
version = "0.3.2"
|
||||
|
||||
[[InteractiveUtils]]
|
||||
deps = ["Markdown"]
|
||||
@ -195,15 +195,15 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
|
||||
|
||||
[[Juno]]
|
||||
deps = ["Base64", "Logging", "Media", "Profile"]
|
||||
git-tree-sha1 = "a686b0cf235fa3e491b79b4783c2d2382292b436"
|
||||
git-tree-sha1 = "e1ba2a612645b3e07c773c3a208f215745081fe6"
|
||||
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
version = "0.8.2"
|
||||
version = "0.8.1"
|
||||
|
||||
[[LLVM]]
|
||||
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
|
||||
git-tree-sha1 = "dd3f584c3dbefe39b2a8fbafa1a3b77e31e21255"
|
||||
git-tree-sha1 = "93d2e1e960fe47db1a9015e86fad1d47cf67cf59"
|
||||
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
|
||||
version = "1.5.1"
|
||||
version = "1.4.1"
|
||||
|
||||
[[LibGit2]]
|
||||
deps = ["Printf"]
|
||||
@ -319,9 +319,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
|
||||
|
||||
[[SpecialFunctions]]
|
||||
deps = ["OpenSpecFun_jll"]
|
||||
git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020"
|
||||
git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c"
|
||||
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
|
||||
version = "0.10.3"
|
||||
version = "0.10.0"
|
||||
|
||||
[[StaticArrays]]
|
||||
deps = ["LinearAlgebra", "Random", "Statistics"]
|
||||
@ -345,9 +345,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
||||
|
||||
[[TimerOutputs]]
|
||||
deps = ["Printf"]
|
||||
git-tree-sha1 = "f458ca23ff80e46a630922c555d838303e4b9603"
|
||||
git-tree-sha1 = "0cc8db57cb537191b02948d4fabdc09eb7f31f98"
|
||||
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
|
||||
version = "0.5.6"
|
||||
version = "0.5.5"
|
||||
|
||||
[[TranscodingStreams]]
|
||||
deps = ["Random", "Test"]
|
||||
@ -364,15 +364,15 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
|
||||
|
||||
[[ZipFile]]
|
||||
deps = ["Libdl", "Printf", "Zlib_jll"]
|
||||
git-tree-sha1 = "254975fef2fc526583bb9b7c9420fe66ffe09f2f"
|
||||
git-tree-sha1 = "8748302cfdec02c4ae9c97b112cf10003f7f767f"
|
||||
uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
|
||||
version = "0.9.2"
|
||||
version = "0.9.1"
|
||||
|
||||
[[Zlib_jll]]
|
||||
deps = ["Libdl", "Pkg"]
|
||||
git-tree-sha1 = "a2e0d558f6031002e380a90613b199e37a8565bf"
|
||||
git-tree-sha1 = "2f6c3e15e20e036ee0a0965879b31442b7ec50fa"
|
||||
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
|
||||
version = "1.2.11+10"
|
||||
version = "1.2.11+9"
|
||||
|
||||
[[Zygote]]
|
||||
deps = ["AbstractFFTs", "ArrayLayouts", "DiffRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
|
||||
|
5
NEWS.md
5
NEWS.md
@ -1,8 +1,3 @@
|
||||
# v0.11
|
||||
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
|
||||
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
|
||||
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
|
||||
|
||||
# v0.10.5
|
||||
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
|
||||
* Added option to set `bias` to [Flux.Zeros](https://github.com/FluxML/Flux.jl/pull/873) to eliminating `bias` from being trained.
|
||||
|
@ -1,6 +1,6 @@
|
||||
name = "Flux"
|
||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||
version = "0.11.0-DEV"
|
||||
version = "0.10.5"
|
||||
|
||||
[deps]
|
||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||
@ -27,7 +27,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
|
||||
|
||||
[compat]
|
||||
AbstractTrees = "0.2, 0.3"
|
||||
Adapt = "1, 2.0"
|
||||
Adapt = "1"
|
||||
CodecZlib = "0.5, 0.6, 0.7"
|
||||
Colors = "0.8, 0.9, 0.10, 0.11, 0.12"
|
||||
CuArrays = "2"
|
||||
|
@ -7,15 +7,15 @@ julia> using Flux: onehot, onecold
|
||||
|
||||
julia> onehot(:b, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
1
|
||||
0
|
||||
false
|
||||
true
|
||||
false
|
||||
|
||||
julia> onehot(:c, [:a, :b, :c])
|
||||
3-element Flux.OneHotVector:
|
||||
0
|
||||
0
|
||||
1
|
||||
false
|
||||
false
|
||||
true
|
||||
```
|
||||
|
||||
The inverse is `onecold` (which can take a general probability distribution, as well as just booleans).
|
||||
|
@ -39,7 +39,7 @@ E.g. the following will have run into the same problem as above:
|
||||
leaky_tanh(x) = 0.01*x + tanh(x)
|
||||
```
|
||||
|
||||
While one could change the activation function (e.g. to use `0.01f0*x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
|
||||
While one could change the activation function (e.g. to use `0.01f0x`), the idiomatic (and safe way) to avoid type casts whenever inputs changes is to use `oftype`:
|
||||
```
|
||||
leaky_tanh(x) = oftype(x/1, 0.01)*x + tanh(x)
|
||||
```
|
||||
|
@ -142,7 +142,7 @@ function my_custom_train!(loss, ps, data, opt)
|
||||
for d in data
|
||||
gs = gradient(ps) do
|
||||
training_loss = loss(d...)
|
||||
# Insert whatever code you want here that needs Training loss, e.g. logging
|
||||
# Insert what ever code you want here that needs Training loss, e.g. logging
|
||||
return training_loss
|
||||
end
|
||||
# insert what ever code you want here that needs gradient
|
||||
|
@ -51,6 +51,4 @@ export Iris
|
||||
include("housing.jl")
|
||||
export Housing
|
||||
|
||||
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
|
||||
|
||||
end
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||
|
||||
struct DataLoader{D}
|
||||
data::D
|
||||
struct DataLoader
|
||||
data
|
||||
batchsize::Int
|
||||
nobs::Int
|
||||
partial::Bool
|
||||
@ -11,20 +11,21 @@ struct DataLoader{D}
|
||||
end
|
||||
|
||||
"""
|
||||
DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
|
||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||
(except possibly the last one).
|
||||
|
||||
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
|
||||
The last dimension in each tensor is considered to be the observation dimension.
|
||||
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
|
||||
supervised learning. The last dimension in each tensor is considered to be the observation
|
||||
dimension.
|
||||
|
||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||
|
||||
The original data is preserved in the `data` field of the DataLoader.
|
||||
The original data is preserved as a tuple in the `data` field of the DataLoader.
|
||||
|
||||
Usage example:
|
||||
Example usage:
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||
@ -36,16 +37,9 @@ Usage example:
|
||||
|
||||
train_loader.data # original dataset
|
||||
|
||||
# similar, but yielding tuples
|
||||
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||
for (x,) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
Ytrain = rand(100)
|
||||
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
|
||||
for epoch in 1:100
|
||||
for (x, y) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
@ -57,26 +51,26 @@ Usage example:
|
||||
# train for 10 epochs
|
||||
using IterTools: ncycle
|
||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||
|
||||
# can use NamedTuple to name tensors
|
||||
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
|
||||
for datum in train_loader
|
||||
@assert size(datum.images) == (10, 2)
|
||||
@assert size(datum.labels) == (2,)
|
||||
end
|
||||
"""
|
||||
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
||||
|
||||
n = _nobs(data)
|
||||
if n < batchsize
|
||||
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
|
||||
batchsize = n
|
||||
nx = size(data[1])[end]
|
||||
for i=2:length(data)
|
||||
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
end
|
||||
imax = partial ? n : n - batchsize + 1
|
||||
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||
if nx < batchsize
|
||||
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
|
||||
batchsize = nx
|
||||
end
|
||||
imax = partial ? nx : nx - batchsize + 1
|
||||
ids = 1:min(nx, batchsize)
|
||||
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
|
||||
end
|
||||
|
||||
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
||||
|
||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||
i >= d.imax && return nothing
|
||||
if d.shuffle && i == 0
|
||||
@ -84,7 +78,11 @@ end
|
||||
end
|
||||
nexti = min(i + d.batchsize, d.nobs)
|
||||
ids = d.indices[i+1:nexti]
|
||||
batch = _getobs(d.data, ids)
|
||||
if length(d.data) == 1
|
||||
batch = getdata(d.data[1], ids)
|
||||
else
|
||||
batch = ((getdata(x, ids) for x in d.data)...,)
|
||||
end
|
||||
return (batch, nexti)
|
||||
end
|
||||
|
||||
@ -92,19 +90,3 @@ function Base.length(d::DataLoader)
|
||||
n = d.nobs / d.batchsize
|
||||
d.partial ? ceil(Int,n) : floor(Int,n)
|
||||
end
|
||||
|
||||
_nobs(data::AbstractArray) = size(data)[end]
|
||||
|
||||
function _nobs(data::Union{Tuple, NamedTuple})
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
n = _nobs(data[1])
|
||||
if !all(x -> _nobs(x) == n, Base.tail(data))
|
||||
throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
end
|
||||
return n
|
||||
end
|
||||
|
||||
_getobs(data::AbstractArray, i) = data[ntuple(i -> Colon(), Val(ndims(data) - 1))..., i]
|
||||
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
|
||||
|
||||
Base.eltype(::DataLoader{D}) where D = D
|
||||
|
@ -24,7 +24,7 @@ testmode!(m, mode = true) = m
|
||||
trainmode!(m, mode = true)
|
||||
|
||||
Set a layer of model's train mode (see below).
|
||||
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)`).
|
||||
Symmetric to [`testmode!`](@ref) (i.e. `trainmode!(m, mode) == testmode!(m, !mode)).
|
||||
|
||||
_Note_: if you manually set a model into train mode, you need to manually place
|
||||
it into test mode during testing phase.
|
||||
|
@ -102,7 +102,7 @@ julia> d(rand(5))
|
||||
-0.16210233
|
||||
0.12311903```
|
||||
"""
|
||||
struct Dense{F,S<:AbstractArray,T<:AbstractArray}
|
||||
struct Dense{F,S,T}
|
||||
W::S
|
||||
b::T
|
||||
σ::F
|
||||
|
@ -132,7 +132,7 @@ end
|
||||
function (c::Conv)(x::AbstractArray)
|
||||
# TODO: breaks gpu broadcast :(
|
||||
# ndims(x) == ndims(c.weight)-1 && return squeezebatch(c(reshape(x, size(x)..., 1)))
|
||||
σ, b = c.σ, reshape(c.bias, ntuple(_->1, length(c.stride))..., :, 1)
|
||||
σ, b = c.σ, reshape(c.bias, map(_->1, c.stride)..., :, 1)
|
||||
cdims = DenseConvDims(x, c.weight; stride=c.stride, padding=c.pad, dilation=c.dilation)
|
||||
σ.(conv(x, c.weight, cdims) .+ b)
|
||||
end
|
||||
@ -222,7 +222,7 @@ end
|
||||
function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;
|
||||
init = glorot_uniform, stride = 1, pad = 0, dilation = 1,
|
||||
weight = convfilter(k, reverse(ch), init = init), bias = zeros(ch[2])) where N
|
||||
|
||||
|
||||
ConvTranspose(weight, bias, σ,
|
||||
stride = stride, pad = pad, dilation = dilation)
|
||||
end
|
||||
|
@ -46,10 +46,9 @@ given the prediction `ŷ` and true values `y`.
|
||||
Huber loss = |
|
||||
| δ * (|ŷ - y| - 0.5 * δ), otherwise
|
||||
"""
|
||||
#TODO: remove dropgrad when Zygote can handle this function with CuArrays
|
||||
function huber_loss(ŷ, y; δ=eltype(ŷ)(1))
|
||||
abs_error = abs.(ŷ .- y)
|
||||
temp = Zygote.dropgrad(abs_error .< δ)
|
||||
temp = abs_error .< δ
|
||||
x = eltype(ŷ)(0.5)
|
||||
hub_loss = sum(((abs_error.^2) .* temp) .* x .+ δ*(abs_error .- x*δ) .* (1 .- temp)) * 1 // length(y)
|
||||
end
|
||||
|
@ -27,8 +27,7 @@ Base.getindex(xs::OneHotMatrix, ::Colon, ::Colon) = OneHotMatrix(xs.height, copy
|
||||
|
||||
Base.getindex(xs::OneHotMatrix, i::Integer, ::Colon) = map(x -> x[i], xs.data)
|
||||
|
||||
# remove workaround when https://github.com/JuliaGPU/CuArrays.jl/issues/676 is fixed
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, cpu(map(x->x.ix, B.data))]
|
||||
A::AbstractMatrix * B::OneHotMatrix = A[:, map(x->x.ix, B.data)]
|
||||
|
||||
Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix(length(x), [x, xs...])
|
||||
|
||||
@ -49,7 +48,7 @@ cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.d
|
||||
Create a `OneHotVector` with its `l`-th element `true` based on the
|
||||
possible set of `labels`.
|
||||
If `unk` is given, return `onehot(unk, labels)` if the input label `l` is not found
|
||||
in `labels`; otherwise, it will raise an error.
|
||||
in `labels`; otherwise it will error.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
|
@ -68,7 +68,8 @@ and compute the gradient of `loss(d)`.
|
||||
A callback is given with the keyword argument `cb`. For example, this will print
|
||||
"training" every 10 seconds (using [`Flux.throttle`](@ref)):
|
||||
|
||||
train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))
|
||||
train!(loss, params, data, opt,
|
||||
cb = throttle(() -> println("training"), 10))
|
||||
|
||||
The callback can call [`Flux.stop`](@ref) to interrupt the training loop.
|
||||
|
||||
|
@ -246,10 +246,6 @@ function _restructure(m, xs)
|
||||
end
|
||||
end
|
||||
|
||||
@adjoint function _restructure(m, xs)
|
||||
_restructure(m, xs), dm -> (nothing,destructure(dm)[1])
|
||||
end
|
||||
|
||||
"""
|
||||
destructure(m)
|
||||
|
||||
|
34
test/data.jl
34
test/data.jl
@ -3,34 +3,20 @@
|
||||
Y = [1:5;]
|
||||
|
||||
d = DataLoader(X, batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 3
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
@test batches[3] == X[:,5:5]
|
||||
|
||||
d = DataLoader(X, batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
|
||||
d = DataLoader((X,), batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
d = DataLoader(X, Y, batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == (X[:,1:2],)
|
||||
@test batches[2] == (X[:,3:4],)
|
||||
|
||||
d = DataLoader((X, Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
@ -42,22 +28,6 @@
|
||||
@test batches[3][1] == X[:,5:5]
|
||||
@test batches[3][2] == Y[5:5]
|
||||
|
||||
# test with NamedTuple
|
||||
d = DataLoader((x=X, y=Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
@test length(batches[3]) == 2
|
||||
@test batches[1][1] == batches[1].x == X[:,1:2]
|
||||
@test batches[1][2] == batches[1].y == Y[1:2]
|
||||
@test batches[2][1] == batches[2].x == X[:,3:4]
|
||||
@test batches[2][2] == batches[2].y == Y[3:4]
|
||||
@test batches[3][1] == batches[3].x == X[:,5:5]
|
||||
@test batches[3][2] == batches[3].y == Y[5:5]
|
||||
|
||||
# test interaction with `train!`
|
||||
θ = ones(2)
|
||||
X = zeros(2, 10)
|
||||
@ -71,7 +41,7 @@
|
||||
X = ones(2, 10)
|
||||
Y = fill(2, 10)
|
||||
loss(x, y) = sum((y - x'*θ).^2)
|
||||
d = DataLoader((X, Y))
|
||||
d = DataLoader(X, Y)
|
||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||
@test norm(θ .- 1) < 1e-10
|
||||
end
|
||||
|
@ -28,14 +28,6 @@ import Flux: activations
|
||||
end
|
||||
|
||||
@testset "Dense" begin
|
||||
@testset "constructors" begin
|
||||
@test size(Dense(10, 100).W) == (100, 10)
|
||||
@test Dense(rand(100,10), rand(10)).σ == identity
|
||||
|
||||
@test_throws MethodError Dense(10, 10.5)
|
||||
@test_throws MethodError Dense(10, 10.5, tanh)
|
||||
end
|
||||
|
||||
@test length(Dense(10, 5)(randn(10))) == 5
|
||||
@test_throws DimensionMismatch Dense(10, 5)(randn(1))
|
||||
@test_throws MethodError Dense(10, 5)(1) # avoid broadcasting
|
||||
@ -45,6 +37,7 @@ import Flux: activations
|
||||
@test Dense(10, 1, identity, initW = ones, initb = zeros)(ones(10,2)) == 10*ones(1, 2)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)(ones(10,1)) == 10*ones(2, 1)
|
||||
@test Dense(10, 2, identity, initW = ones, initb = zeros)([ones(10,1) 2*ones(10,1)]) == [10 20; 10 20]
|
||||
|
||||
end
|
||||
|
||||
@testset "Diagonal" begin
|
||||
|
@ -2,45 +2,49 @@ using Flux
|
||||
using Flux.Data
|
||||
using Test
|
||||
using Random, Statistics, LinearAlgebra
|
||||
using Documenter
|
||||
using IterTools: ncycle
|
||||
|
||||
Random.seed!(0)
|
||||
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
end
|
||||
@testset "Flux" begin
|
||||
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
end
|
||||
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@static if VERSION >= v"1.4"
|
||||
using Documenter
|
||||
@testset "Docs" begin
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
if VERSION >= v"1.4"
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
end # testset Flux
|
||||
|
Loading…
Reference in New Issue
Block a user