Flux is an elegant approach to machine learning. It's a 100% pure-Julia stack, and provides lightweight abstractions on top of Julia's native GPU and AD support. Flux makes the easy things easy while remaining fully hackable.
julia> using Pkg; Pkg.add("Flux")
See the documentation or the model zoo for examples.
If you use Flux in research, please cite the following paper:
@article{innes:2018,
author = {Mike Innes},
title = {Flux: Elegant Machine Learning with Julia},
journal = {Journal of Open Source Software},
year = {2018},
doi = {10.21105/joss.00602},
}
Features
Flux has powerful high-level features, and common architectures can be defined in a few lines.
model = Chain(
Dense(768, 128, σ),
LSTM(128, 256),
LSTM(256, 128),
Dense(128, 10),
softmax)
loss(x, y) = crossentropy(model(x), y)
Flux.train!(loss, params(model), data, ADAM(...))
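Here data can be any iterator of (input, label) batches; the random batch below is purely a placeholder so the call has something to run on:

using Flux: onehotbatch, params

xb = rand(Float32, 768, 100)              # a made-up batch of 100 inputs
yb = onehotbatch(rand(1:10, 100), 1:10)   # made-up labels over 10 classes
Flux.train!(loss, params(model), [(xb, yb)], ADAM())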
Yet you can easily strip away the layers, and directly write the mathematics for your problem. Flux will seamlessly take gradients of any Julia code, so your model looks just like the paper.
W = param(randn(2, 10))
b = param(randn(2))
y(x) = σ.(W * x .+ b)
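Gradients then come from the usual machinery; here is a minimal sketch written against the Tracker-based API used above (param), with a placeholder input and squared-error loss. Newer Flux versions use Zygote and plain arrays instead, but the idea is the same:

using Flux, Flux.Tracker

x, target = rand(10), rand(2)              # placeholder data
loss() = sum((y(x) .- target) .^ 2)        # any scalar loss will do
gs = Tracker.gradient(loss, params(W, b))  # gradients w.r.t. W and b
gs[W]                                      # an array the same shape as W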
If that's still not enough, you can go as deep as you want, even writing your own CUDA kernels with CUDAnative! All this can be freely mixed and matched in a single model or script, and it all runs interactively via Jupyter or Juno.
using CUDAnative

function gpu_add(a, b, c)
  # One GPU thread handles one element of the arrays.
  i = (blockIdx().x-1) * blockDim().x + threadIdx().x
  c[i] = a[i] + b[i]
  return nothing
end
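Launching that kernel is just as direct; a rough sketch, assuming CuArrays is installed and a CUDA-capable GPU is present (the array size and launch configuration here are arbitrary):

using CuArrays

a = CuArray(fill(1.0f0, 256))
b = CuArray(fill(2.0f0, 256))
c = CuArray(zeros(Float32, 256))

@cuda threads=256 gpu_add(a, b, c)   # one GPU thread per element
@assert all(Array(c) .== 3.0f0)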
Unusual architectures are no problem in Flux, as you can use all the loops, control flow and even macros that you're used to. Here's a Tree RNN in 4 lines.
tree() = rand() < 0.5 ? rand(10) : (tree(), tree()) # dummy data
shrink = Dense(20, 10)
combine(a, b) = shrink([a; b])
model(x) = x
model(x::Tuple) = combine(model(x[1]), model(x[2]))
model(tree()) # Sample output
Despite this flexibility, Julia's advanced compiler lets us do some powerful optimisations. For example, this definition of sigmoid automatically gets fused into a single GPU kernel, so it's really fast.
sigmoid(xs) = 1 ./ (1 .+ exp.(.-xs))
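For instance, broadcasting that definition over a GPU array runs as one fused kernel rather than several separate passes (a sketch assuming CuArrays and a GPU are available):

using CuArrays

xs = cu(rand(10^6))   # move a million random numbers to the GPU
ys = sigmoid(xs)      # executes as a single fused GPU kernel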
Similarly, Flux is the first dynamic framework to support compiling to the browser and model import via formats like ONNX, both of which are thinly-veiled compiler problems.
For more on our philosophy on machine learning, check out our article On Machine Learning & Programming Languages.
Contributing & Help
For general questions and help, check out Julia's community forum.
Flux development is carried out via our GitHub issues, so feel free to open feature requests or PRs here.
For more informal discussions we'd love to have you on the Julia Slack, where we hang out in the #machine-learning channel.
Related Packages
Check out Metalhead.jl for common computer vision datasets and trained models.
MLDatasets.jl provides further common datasets.