Add missing docstrings to `src/utils.jl`
Not sure about the `stack`, `unstack` and `unsqueeze` functions.
This commit is contained in:
parent
2f955a33cd
commit
c222e1b124
121
src/utils.jl
121
src/utils.jl
|
@ -4,7 +4,37 @@ nfan(n) = 1, n #A vector is treated as a n×1 matrix
|
|||
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
|
||||
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
|
||||
|
||||
"""
|
||||
glorot_uniform(dims...)
|
||||
|
||||
Return an `Array` of size `dims` containing random variables taken from a uniform
|
||||
distribution in the interval ``[-x, x]``, where `x = sqrt(24 / sum(dims)) / 2`.
|
||||
|
||||
# Examples
|
||||
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||
julia> Flux.glorot_uniform(2, 3)
|
||||
2×3 Array{Float32,2}:
|
||||
0.601094 -0.57414 -0.814925
|
||||
0.900868 0.805994 0.057514
|
||||
```
|
||||
"""
|
||||
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
|
||||
|
||||
"""
|
||||
glorot_normal(dims...)
|
||||
|
||||
Return an `Array` of size `dims` containing random variables taken from a normal
|
||||
distribution with mean 0 and standard deviation `(2 / sum(dims))`.
|
||||
|
||||
# Examples
|
||||
```jldoctest; setup = :(using Random; Random.seed!(0))
|
||||
julia> Flux.glorot_normal(3, 2)
|
||||
3×2 Array{Float32,2}:
|
||||
0.429505 -0.0852891
|
||||
0.523935 0.371009
|
||||
-0.223261 0.188052
|
||||
```
|
||||
"""
|
||||
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
|
||||
|
||||
ones(T::Type, dims...) = Base.ones(T, dims...)
|
||||
|
@ -13,9 +43,81 @@ zeros(T::Type, dims...) = Base.zeros(T, dims...)
|
|||
ones(dims...) = Base.ones(Float32, dims...)
|
||||
zeros(dims...) = Base.zeros(Float32, dims...)
|
||||
|
||||
"""
|
||||
unsqueeze(xs, dim)
|
||||
|
||||
Return `xs` reshaped into an `Array` one dimensionality higher than `xs`,
|
||||
where `dim` indicates in which dimension `xs` is extended.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
|
||||
julia> Flux.unsqueeze(xs, 1)
|
||||
1×3 Array{Array{Int64,1},2}:
|
||||
[1, 2] [3, 4] [5, 6]
|
||||
|
||||
julia> Flux.unsqueeze([1 2; 3 4], 2)
|
||||
2×1×2 Array{Int64,3}:
|
||||
[:, :, 1] =
|
||||
1
|
||||
3
|
||||
|
||||
[:, :, 2] =
|
||||
2
|
||||
4
|
||||
```
|
||||
"""
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
"""
|
||||
stack(xs, dim)
|
||||
|
||||
Concatenate the given `Array` of `Array`s `xs` into a single `Array` along the
|
||||
given dimension `dim`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> xs = [[1, 2], [3, 4], [5, 6]]
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
|
||||
julia> Flux.stack(xs, 1)
|
||||
3×2 Array{Int64,2}:
|
||||
1 2
|
||||
3 4
|
||||
5 6
|
||||
|
||||
julia> cat(xs, dims=1)
|
||||
3-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
```
|
||||
"""
|
||||
stack(xs, dim) = cat(unsqueeze.(xs, dim)..., dims=dim)
|
||||
|
||||
"""
|
||||
unstack(xs, dim)
|
||||
|
||||
Unroll the given `xs` into an `Array` of `Array`s along the given dimension `dim`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> Flux.unstack([1 3 5 7; 2 4 6 8], 2)
|
||||
4-element Array{Array{Int64,1},1}:
|
||||
[1, 2]
|
||||
[3, 4]
|
||||
[5, 6]
|
||||
[7, 8]
|
||||
```
|
||||
"""
|
||||
unstack(xs, dim) = [copy(selectdim(xs, dim, i)) for i in 1:size(xs, dim)]
|
||||
|
||||
"""
|
||||
|
@ -82,6 +184,25 @@ function batch(xs)
|
|||
return data
|
||||
end
|
||||
|
||||
"""
|
||||
Return the given sequence padded with `p` up to a maximum length of `n`.
|
||||
|
||||
# Examples
|
||||
```jldoctest
|
||||
julia> rpad([1, 2], 4, 0)
|
||||
4-element Array{Int64,1}:
|
||||
1
|
||||
2
|
||||
0
|
||||
0
|
||||
|
||||
julia> rpad([1, 2, 3], 2, 0)
|
||||
3-element Array{Int64,1}:
|
||||
1
|
||||
2
|
||||
3
|
||||
```
|
||||
"""
|
||||
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue