Flux.jl/src/utils.jl

156 lines
3.4 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Arrays
initn(dims...) = randn(dims...)/100
glorot_uniform(dims...) = (rand(dims...) .- 0.5) .* sqrt(24.0/(sum(dims)))
glorot_normal(dims...) = randn(dims...) .* sqrt(2.0/sum(dims))
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
"""
chunk(xs, n)
Split `xs` into `n` parts.
```julia
julia> chunk(1:10, 3)
3-element Array{Array{Int64,1},1}:
[1, 2, 3, 4]
[5, 6, 7, 8]
[9, 10]
```
"""
chunk(xs, n) = collect(Iterators.partition(xs, ceil(Int, length(xs)/n)))
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
"""
frequencies(xs)
Count the number of times that each element of `xs` appears.
```julia
julia> frequencies(['a','b','b'])
Dict{Char,Int64} with 2 entries:
'b' => 2
'a' => 1
```
"""
function frequencies(xs)
fs = Dict{eltype(xs),Int}()
for x in xs
fs[x] = get(fs, x, 0) + 1
end
return fs
end
head(x::Tuple) = reverse(Base.tail(reverse(x)))
squeezebatch(x) = reshape(x, head(size(x)))
"""
batch(xs)
Batch the arrays in `xs` into a single array.
```julia
julia> batch([[1,2,3],[4,5,6]])
3×2 Array{Int64,2}:
1 4
2 5
3 6
```
"""
function batch(xs)
data = first(xs) isa AbstractArray ?
similar(first(xs), size(first(xs))..., length(xs)) :
Vector{eltype(xs)}(undef, length(xs))
for (i, x) in enumerate(xs)
data[batchindex(data, i)...] = x
end
return data
end
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
"""
batchseq(seqs, pad)
Take a list of `N` sequences, and turn them into a single sequence where each
item is a batch of `N`. Short sequences will be padded by `pad`.
```julia
julia> batchseq([[1, 2, 3], [4, 5]], 0)
3-element Array{Array{Int64,1},1}:
[1, 4]
[2, 5]
[3, 0]
```
"""
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
xs_ = [rpad(x, n, pad) for x in xs]
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
end
# Other
"""
Returns a function that when invoked, will only be triggered at most once
during `timeout` seconds. Normally, the throttled function will run
as much as it can, without ever going more than once per `wait` duration;
but if you'd like to disable the execution on the leading edge, pass
`leading=false`. To enable execution on the trailing edge, ditto.
"""
function throttle(f, timeout; leading=true, trailing=false)
cooldown = true
later = nothing
result = nothing
function throttled(args...; kwargs...)
yield()
if cooldown
if leading
result = f(args...; kwargs...)
else
later = () -> f(args...; kwargs...)
end
cooldown = false
@async try
while (sleep(timeout); later != nothing)
later()
later = nothing
end
finally
cooldown = true
end
elseif trailing
later = () -> (result = f(args...; kwargs...))
end
return result
end
end
"""
J = jacobian(m,x)
Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])`
"""
function jacobian(m,x)
xp = param(x)
y = m(xp)
k = length(y)
n = length(x)
J = Matrix{eltype(x)}(undef,n,k)
for i = 1:k
Flux.back!(y[i]) # Populate gradient accumulator
J[:,i] = xp.grad
xp.grad .*= 0 # Reset gradient accumulator
end
J'
end