Merge branch 'master' into cat-fix

Commit 2084df96ae
docs/make.jl
@@ -14,8 +14,8 @@ makedocs(modules=[Flux],
          "Training Models" =>
            ["Optimisers" => "training/optimisers.md",
             "Training" => "training/training.md"],
-         "Data Munging" =>
-           ["One-Hot Encoding" => "data/onehot.md"],
+         "One-Hot Encoding" => "data/onehot.md",
+         "GPU Support" => "gpu.md",
          "Contributing & Help" => "contributing.md"])
 
 deploydocs(
docs/src/gpu.md (new file, 35 lines)
@@ -0,0 +1,35 @@
+# GPU Support
+
+Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CuArrays](https://github.com/JuliaGPU/CuArrays.jl) and [CLArrays](https://github.com/JuliaGPU/CLArrays.jl). Flux doesn't care what array type you use, so we can just plug these in without any other changes.
+
+For example, we can use `CuArrays` (with the `cu` converter) to run our [basic example](models/basics.md) on an NVIDIA GPU.
+
+```julia
+using CuArrays
+
+W = cu(rand(2, 5)) # a 2×5 CuArray
+b = cu(rand(2))
+
+predict(x) = W*x .+ b
+loss(x, y) = sum((predict(x) .- y).^2)
+
+x, y = cu(rand(5)), cu(rand(2)) # Dummy data
+loss(x, y) # ~ 3
+```
+
+Note that we convert both the parameters (`W`, `b`) and the data set (`x`, `y`) to CUDA arrays. Taking derivatives and training works exactly as before.
+
+If you define a structured model, like a `Dense` layer or `Chain`, you just need to convert the internal parameters. Flux provides `mapparams`, which allows you to alter all parameters of a model at once.
+
+```julia
+d = Dense(10, 5, σ)
+d = mapparams(cu, d)
+d.W # Tracked CuArray
+d(cu(rand(10))) # CuArray output
+
+m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
+m = mapparams(cu, m)
+m(cu(rand(10)))
+```
+
+The [mnist example](https://github.com/FluxML/model-zoo/blob/master/mnist/mnist.jl) contains the code needed to run the model on the GPU; just uncomment the lines after `using CuArrays`.
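As a quick end-to-end illustration (a hedged sketch, not part of this commit — it assumes a working CuArrays setup and combines the snippets above with the `train!` API from the training docs):

```julia
using Flux, CuArrays

m = mapparams(cu, Chain(Dense(10, 5, σ), Dense(5, 2), softmax))

loss(x, y) = Flux.mse(m(x), y)
data = [(cu(rand(10)), cu(rand(2)))]  # the data must be moved to the GPU too
opt = SGD(params(m))

Flux.train!(loss, data, opt)
```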
docs/src/training/training.md
@@ -2,14 +2,14 @@
 
 To actually train a model we need three things:
 
-* A *loss function*, that evaluates how well a model is doing given some input data.
+* A *model loss function*, that evaluates how well a model is doing given some input data.
 * A collection of data points that will be provided to the loss function.
 * An [optimiser](optimisers.md) that will update the model parameters appropriately.
 
 With these we can call `Flux.train!`:
 
 ```julia
-Flux.train!(loss, data, opt)
+Flux.train!(modelLoss, data, opt)
 ```
 
 There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -23,7 +23,11 @@ m = Chain(
   Dense(784, 32, σ),
   Dense(32, 10), softmax)
 
+# Model loss function
 loss(x, y) = Flux.mse(m(x), y)
 
+# later
+Flux.train!(loss, data, opt)
+
 ```
 
 The loss will almost always be defined in terms of some *cost function* that measures the distance of the prediction `m(x)` from the target `y`. Flux has several of these built in, like `mse` for mean squared error or `logloss` for cross entropy loss, but you can calculate it however you want.
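To make the three ingredients concrete, here's a hedged, self-contained sketch (random data, purely illustrative):

```julia
using Flux

m = Chain(Dense(784, 32, σ), Dense(32, 10), softmax)

loss(x, y) = Flux.mse(m(x), y)               # model loss function
data = [(rand(784), rand(10)) for _ = 1:10]  # collection of (input, target) pairs
opt = SGD(params(m), 0.1)                    # optimiser over the model's parameters

Flux.train!(loss, data, opt)
```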
src/Flux.jl
@@ -4,11 +4,11 @@ module Flux
 
 # Zero Flux Given
 
-using Juno
+using Juno, Requires
 using Lazy: @forward
 
 export Chain, Dense, RNN, LSTM,
-  SGD, params
+  SGD, params, mapparams
 
 using NNlib
 export σ, relu, softmax
@@ -21,8 +21,7 @@ using .Optimise
 
 include("utils.jl")
 include("onehot.jl")
+include("tree.jl")
-include("compiler/Compiler.jl")
-
 
 include("layers/stateless.jl")
 include("layers/basic.jl")
src/compiler/Compiler.jl (deleted file, 14 lines)
@@ -1,14 +0,0 @@
-module Compiler
-
-using MacroTools, DataFlow, DataFlow.Interpreter
-
-using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
-  iscyclic, Constant, constant, isconstant, group, Split,
-  detuple, value, inputs, thread!, value, inputs, inputnode,
-  spliceinputs, bumpinputs, Line, Frame, applylines, graphinputs
-
-include("code.jl")
-include("interp.jl")
-include("loops.jl")
-
-end
src/compiler/code.jl (deleted file, 77 lines)
@@ -1,77 +0,0 @@
-import DataFlow: cse
-using MacroTools: @q, @>
-
-graph(m) = nothing
-
-function graphdef(ex, params = [])
-  @capture(shortdef(ex), (args__,) -> body_)
-  body = @> body MacroTools.flatten liftloops graphm DataFlow.il
-  body = map(x -> x in params ? :(self.$x) : x, body)
-  return args, body
-end
-
-function makegraph(graph, args, params = [])
-  graph = prewalk(graph) do v
-    isconstant(v) && (i = findfirst(args, value(v[1]))) ≠ 0 ?
-      inputnode(i) :
-      v
-  end
-  graph = map(graph) do x
-    x isa Offset ?
-      :(Flux.Compiler.Offset($(Expr(:quote, x.name)), $(x.n),
-        $(x.name in params ? :(self.$(x.name)) : x.name))) :
-      x
-  end
-  vertex(:($DataFlow.Frame(self)), graph)
-end
-
-function build_type(T, params)
-  @esc T
-  :(type $T
-      $(params...)
-    end)
-end
-
-function build_forward(body, args)
-  iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
-  applylines(syntax(cse(body)))
-end
-
-import Lazy: groupby
-
-# TODO: type hints for parameters
-
-function process_type(ex)
-  @capture(ex, type T_ fs__ end)
-  @destruct [params = false || [],
-             funcs = true || []] = groupby(x->isexpr(x, :->, :function), fs)
-  @assert length(funcs) == 1
-  pnames = namify.(params)
-  args, body = graphdef(funcs[1], pnames)
-  self = esc(:self)
-  quote
-    $(build_type(T, params))
-    $(esc(:((self::$T)($(args...)) = $(build_forward(body, args)))))
-    $(esc(:(Flux.Compiler.graph(self::$T)))) = $(DataFlow.constructor(map(esc, makegraph(body, args, params))))
-    nothing
-  end
-end
-
-function process_anon(ex)
-  args, body = graphdef(ex)
-  :(Capacitor($(DataFlow.constructor(map(esc, makegraph(body, args)[1])))))
-end
-
-function process_def(ex)
-  # TODO: make a singleton net type
-  @capture(ex, f_(xs__) = body_)
-  :($(esc(f)) = @net $(esc(:(($(xs...),) -> $body))); nothing)
-end
-
-macro net(ex)
-  ex = shortdef(ex)
-  isexpr(ex, :type) ? process_type(ex) :
-  @capture(ex, (__,) -> _) ? process_anon(ex) :
-  @capture(ex, _(__) = _) ? process_def(ex) :
-  error("Unsupported model expression $ex")
-end
src/compiler/interp.jl (deleted file, 39 lines)
@@ -1,39 +0,0 @@
-function astuple(xs::Vertex)
-  isconstant(xs) && value(xs[1]) isa Tuple ? value(xs[1]) :
-  xs isa Vertex && value(xs) == tuple ? inputs(xs) :
-  nothing
-end
-
-astuple(xs::Tuple) = xs
-
-astuple(xs) = nothing
-
-function astuples(xs)
-  xs = [astuple(x) for x in xs]
-  all(x->!(x==nothing), xs) ? xs : nothing
-end
-
-function interp(ctx, f, xs...)
-  g = graph(f)
-  g ≠ nothing && iscyclic(g) && error("Can't interpret cyclic graph")
-  @icatch(ctx, g ≠ nothing ?
-    interpret(ctx, g, xs...) :
-    f(xs...))
-end
-
-function interpmodel(m, args...)
-  ctx = Context(mux(iconst, iline, ilambda, iargs, ituple, interp))
-  @ithrow interp(ctx, m, args...)
-end
-
-# Anonymous models
-
-struct Capacitor
-  graph::IVertex{Any}
-end
-
-(m::Capacitor)(xs...) = interpmodel(m, xs...)
-
-graph(cap::Capacitor) = cap.graph
-
-Base.show(io::IO, ::Capacitor) = print(io, "Capacitor(...)")
src/compiler/loops.jl (deleted file, 191 lines)
@@ -1,191 +0,0 @@
-using ..Flux: stack, unstack, squeeze, unsqueeze
-
-# Stateful Models
-
-mutable struct Stateful
-  model
-  states::Vector{Any}
-  istate::Vector{Any}
-  ostate::Vector{Any}
-end
-
-Stateful(model, ss) = Stateful(model, ss, ss, ss)
-
-function Base.show(io::IO, m::Stateful)
-  print(io, "Stateful(")
-  show(io, m.model)
-  print(io, ")")
-end
-
-function (m::Stateful)(xs...)
-  m.istate = m.ostate
-  state, y = m.model((m.istate...,), xs...)
-  m.ostate = collect(state)
-  return y
-end
-
-# Seq Models
-
-struct SeqModel
-  model
-  steps::Int
-end
-
-seqtuple(x, n) = x
-seqtuple(xs::Tuple, n) = seqtuple.(xs, n)
-
-seqtuple(xs::AbstractArray, n) =
-  ndims(xs) < 3 ? xs :
-  n ≠ 0 && size(xs, 2) ≠ n ? error("Expecting sequence length $n, got $(size(xs, 2))") :
-  (unstack(xs, 2)...)
-
-reseq(x) = x
-reseq(x::Tuple{}) = ()
-reseq(xs::Tuple) = all(isa.(xs, AbstractArray) .& (ndims.(xs) .≥ 2)) ? stack(xs, 2) : reseq.(xs)
-
-function (m::SeqModel)(xs...)
-  xs = seqtuple(xs, m.steps)
-  reseq(m.model(xs...))
-end
-
-graph(m::SeqModel) = graph(m.model)
-
-# Recurrent Graphs
-
-struct Offset
-  name::Symbol
-  n::Int
-  default::Nullable{Any}
-end
-
-Offset(name, n) = Offset(name, n, nothing)
-
-Base.:-(o::Offset) = Offset(o.name, -o.n, o.default)
-
-function liftloops(ex)
-  ex = DataFlow.normedges(ex)
-  decls = Dict()
-  ex = MacroTools.postwalk(ex) do ex
-    @capture(ex, x_{n_}) || return ex
-    haskey(decls, (x,n)) && return namify(decls[(x,n)])
-    @gensym edge
-    decls[(x,n)] = :($edge = $(Offset(x,n))($x))
-    edge
-  end
-  prepend!(ex.args, collect(values(decls)))
-  ex
-end
-
-function hasloops(model)
-  g = graph(model)
-  g == nothing && return false
-  iscyclic(g) && return true
-  result = false
-  map(m -> hasloops(m) && (result = true), g)
-  return result
-end
-
-function atomise(model)
-  postwalk(graph(model)) do v
-    hasloops(value(v)) || return v
-    spliceinputs(atomise(value(v)), inputs(v)...)
-  end
-end
-
-function collect_state(v::IVertex)
-  state = typeof(v)[]
-  offset = Int[]
-  default = []
-  prewalk!(v) do v
-    value(v) isa Offset || return v
-    if (i = findfirst(state, v[1])) == 0
-      push!(state, v[1])
-      push!(offset, max(0, -value(v).n))
-      push!(default, get(value(v).default))
-    else
-      offset[i] = max(offset[i], -value(v).n)
-    end
-    v
-  end
-  return state, offset, default
-end
-
-hiddeninput(n, t) = vertex(Split(t), inputnode(n))
-
-# TODO: nicer way to do this.
-create_steps(v::IVertex, n) = [bumpinputs(spliceinputs(v, [hiddeninput(n, t) for n = 1:graphinputs(v)]...)) for t = 1:n]
-
-function getvar(n, step, steps, offset, default)
-  if step < 1
-    hiddeninput(1, sum(offset[1:n-1]) + 1 - step)
-  elseif step ∉ 1:length(steps)
-    constant(default[n])
-  else
-    steps[step][1,n]
-  end
-end
-
-function stateout(steps, offset, default)
-  outs = []
-  defaults = []
-  for i = 1:length(offset), j = 1:offset[i]
-    push!(outs, getvar(i, length(steps)-j+1, steps, offset, default))
-    push!(defaults, default[i])
-  end
-  group(outs...), defaults
-end
-
-# Input: (hidden1, hidden2, ...), (x1, x2, ...)
-# Output: (hidden1, hidden2, ...), (y1, y2, ...)
-# TODO: make sure there's a reasonable order for hidden states
-
-function unrollgraph(v::IVertex, n)
-  state, offset, default = collect_state(v)
-  v = group(group(state...), v)
-  steps = create_steps(v, n)
-  for i = 1:n
-    vars = inputs(steps[i][1])
-    postwalk!(steps[i]) do v
-      value(v) isa Offset || return v
-      varid = findfirst(vars,v[1])
-      getvar(varid, value(v).n + i, steps, offset, default)
-    end
-  end
-  out = group(map(x->x[2], steps)...)
-  state, defaults = stateout(steps, offset, default)
-  group(state,out), defaults
-end
-
-unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
-
-function unroll(model, n)
-  graph, state = unrollgraph(model, n)
-  SeqModel(Stateful(Capacitor(graph), state), n)
-end
-
-function stateless(s::Stateful)
-  v = graph(s.model)
-  v = spliceinputs(v, group(constant.(s.states)...),
-                   [inputnode(i) for i = 1:graphinputs(v)-1]...)
-  Capacitor(v[2])
-end
-
-stateless(s::SeqModel) = SeqModel(stateless(s.model), s.steps)
-
-function unseqin(v::IVertex)
-  prewalk(v) do v
-    # TODO: inputidx function
-    isa(value(v), Split) && DataFlow.isinput(v[1]) && value(v[1]).n > 1 ? v[1] : v
-  end
-end
-
-unseqout(v::IVertex) = group(v[1], v[2][1])
-
-unseq(graph) = unseqout(unseqin(graph))
-
-function unroll1(model)
-  graph, state = unrollgraph(model, 1)
-  Stateful(Capacitor(unseq(graph)), state)
-end
-
-flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
src/layers/basic.jl
@@ -22,13 +22,11 @@ end
 @forward Chain.layers Base.getindex, Base.first, Base.last, Base.endof, Base.push!
 @forward Chain.layers Base.start, Base.next, Base.done
 
-Optimise.children(c::Chain) = c.layers
+children(c::Chain) = c.layers
+mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
 
-Compiler.graph(s::Chain) =
-  foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
-
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
 
 function Base.show(io::IO, c::Chain)
@@ -56,9 +54,12 @@ end
 Dense(in::Integer, out::Integer, σ = identity; init = initn) =
   Dense(σ, param(init(out, in)), param(init(out)))
 
-Optimise.children(d::Dense) = (d.W, d.b)
+treelike(Dense)
 
-(a::Dense)(x) = a.σ.(a.W*x .+ a.b)
+function (a::Dense)(x)
+  W, b, σ = a.W, a.b, a.σ
+  σ.(W*x .+ b)
+end
 
 function Base.show(io::IO, l::Dense)
   print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
src/layers/recurrent.jl
@@ -16,7 +16,7 @@ function (m::Recur)(xs...)
   return y
 end
 
-Optimise.children(m::Recur) = (m.cell,)
+treelike(Recur)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")
 
@@ -24,7 +24,7 @@ _truncate(x::AbstractArray) = x
 _truncate(x::TrackedArray) = x.data
 _truncate(x::Tuple) = _truncate.(x)
 
-truncate!(m) = foreach(truncate!, Optimise.children(m))
+truncate!(m) = foreach(truncate!, children(m))
 truncate!(m::Recur) = (m.state = _truncate(m.state))
 
 # Vanilla RNN
@@ -44,7 +44,7 @@ end
 
 hidden(m::RNNCell) = m.h
 
-Optimise.children(m::RNNCell) = (m.d, m.h)
+treelike(RNNCell)
 
 function Base.show(io::IO, m::RNNCell)
   print(io, "RNNCell(", m.d, ")")
@@ -63,9 +63,9 @@ struct LSTMCell{D1,D2,V}
 end
 
 function LSTMCell(in, out; init = initn)
-  cell = LSTMCell([Dense(in+out, out, σ, init = initn) for _ = 1:3]...,
-                  Dense(in+out, out, tanh, init = initn),
-                  param(initn(out)), param(initn(out)))
+  cell = LSTMCell([Dense(in+out, out, σ, init = init) for _ = 1:3]...,
+                  Dense(in+out, out, tanh, init = init),
+                  param(init(out)), param(init(out)))
   cell.forget.b.data .= 1
   return cell
 end
@@ -82,8 +82,7 @@ end
 
 hidden(m::LSTMCell) = (m.h, m.c)
 
-Optimise.children(m::LSTMCell) =
-  (m.forget, m.input, m.output, m.cell, m.h, m.c)
+treelike(LSTMCell)
 
 Base.show(io::IO, m::LSTMCell) =
   print(io, "LSTMCell(",
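The `LSTMCell` hunk above is also a bug fix: the `init` keyword was previously shadowed by the hard-coded `initn`. A hypothetical check that the keyword now takes effect (not part of the diff; `zeros` stands in for any custom initialiser):

```julia
c = Flux.LSTMCell(10, 5, init = zeros)
c.cell.W   # zero-initialised, where `initn` would have given small random values
c.forget.b # still all ones — set explicitly after initialisation
```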
src/onehot.jl
@@ -9,11 +9,12 @@ Base.getindex(xs::OneHotVector, i::Integer) = i == xs.ix
 
 Base.:*(A::AbstractMatrix, b::OneHotVector) = A[:, b.ix]
 
-struct OneHotMatrix <: AbstractMatrix{Bool}
-  data::Vector{OneHotVector}
+struct OneHotMatrix{A<:AbstractVector{OneHotVector}} <: AbstractMatrix{Bool}
+  height::Int
+  data::A
 end
 
-Base.size(xs::OneHotMatrix) = (Int64(length(xs.data[1])),length(xs.data))
+Base.size(xs::OneHotMatrix) = (Int64(xs.height),length(xs.data))
 
 Base.getindex(xs::OneHotMatrix, i::Int, j::Int) = xs.data[j][i]
 
@@ -21,8 +22,18 @@ Base.:*(A::AbstractMatrix, B::OneHotMatrix) = A[:, map(x->x.ix, B.data)]
 
 Base.hcat(x::OneHotVector, xs::OneHotVector...) = OneHotMatrix([x, xs...])
 
+import NNlib.adapt
+
+adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data))
+
+@require CuArrays begin
+  import CuArrays: CuArray, cudaconvert
+  Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray
+  cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data))
+end
+
 onehot(l, labels) = OneHotVector(findfirst(labels, l), length(labels))
-onehotbatch(ls, labels) = OneHotMatrix([onehot(l, labels) for l in ls])
+onehotbatch(ls, labels) = OneHotMatrix(length(labels), [onehot(l, labels) for l in ls])
 
 argmax(y::AbstractVector, labels = 1:length(y)) =
   labels[findfirst(y, maximum(y))]
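For reference, a hedged sketch of the updated one-hot API (the label set is made up). Storing `height` explicitly means the matrix knows its size without inspecting `data`:

```julia
v = Flux.onehot(:b, [:a, :b, :c])             # OneHotVector; v[2] == true
m = Flux.onehotbatch([:b, :a], [:a, :b, :c])  # OneHotMatrix(3, [...])
size(m)  # (3, 2)
m[2, 1]  # true — :b is the second label
```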
src/optimise/Optimise.jl
@@ -3,15 +3,19 @@ module Optimise
 export update!, params, train!,
   SGD
 
-include("params.jl")
+struct Param{T}
+  x::T
+  Δ::T
+end
+
+Base.convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
+
 include("optimisers.jl")
 include("interface.jl")
 include("train.jl")
 
 using Flux.Tracker: TrackedArray
 
-params(ps, p::TrackedArray) = push!(ps, p)
-
 Base.convert(::Type{Param}, x::TrackedArray) = Param(x.data, x.grad[])
 
 end
src/optimise/interface.jl
@@ -10,3 +10,9 @@ function optimiser(ps, fs...)
 end
 
 SGD(ps, η = 1) = optimiser(ps, p -> descent(p, η))
+ADAM(ps, η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0.0) = optimiser(ps, p -> adam(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+Momentum(ps, ρ, decay = 0.0) = optimiser(ps, p -> momentum(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+Nesterov(ps, ρ, decay = 0.0) = optimiser(ps, p -> nesterov(p, ρ), p -> invdecay(p, decay), p -> descent(p, 1))
+RMSProp(ps, η = 0.001, ρ = 0.9, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> rmsprop(p; η = η, ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADAGrad(ps, η = 0.01, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adagrad(p; η = η, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
+ADADelta(ps, η = 0.01, ρ = 0.95, ϵ = 1e-8, decay = 0.0) = optimiser(ps, p -> adadelta(p; ρ = ρ, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))
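All of these constructors share one shape: they take the parameter collection and return a zero-argument closure that `train!` invokes after each backward pass. A hedged usage sketch:

```julia
m = Dense(10, 2)
ps = params(m)

opt = SGD(ps, 0.1)             # exported
# opt = Flux.Optimise.ADAM(ps) # the new optimisers aren't exported yet

# inside a training loop, after back!(l):
opt()  # applies one update step to every parameter
```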
src/optimise/params.jl (deleted file, 18 lines)
@@ -1,18 +0,0 @@
-using DataFlow: OSet
-
-children(x) = ()
-
-params(ps, m) = foreach(m -> params(ps, m), children(m))
-
-function params(m)
-  ps = OSet()
-  params(ps, m)
-  return collect(ps)
-end
-
-struct Param{T}
-  x::T
-  Δ::T
-end
-
-convert(::Type{Param}, x::AbstractArray) = Param(x, zeros(x))
src/optimise/train.jl
@@ -8,8 +8,8 @@ function train!(m, data, opt; cb = () -> ())
   cb = tocb(cb)
   @progress for x in data
     l = m(x...)
-    isinf(l.data[]) && error("Inf")
-    isnan(l.data[]) && error("NaN")
+    isinf(l.data[]) && error("Loss is Inf")
+    isnan(l.data[]) && error("Loss is NaN")
     back!(l)
     opt()
     cb()
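`train!` also takes a `cb` keyword, called once per datapoint after the optimiser step — handy for logging around the new loss checks. A minimal hedged sketch (the message is illustrative):

```julia
evalcb() = println("datapoint done")
Flux.train!(loss, data, opt, cb = evalcb)
```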
src/tracker/Tracker.jl
@@ -71,11 +71,10 @@ include("back.jl")
 include("lib.jl")
 include("numeric.jl")
 
-using Requires
+import NNlib.adapt
 
-@require CuArrays begin
-  import CuArrays: cu
-  cu(xs::TrackedArray) = TrackedArray(xs.f, cu(xs.data), RefValue(cu(grad(xs))))
-end
+adapt(T, xs::TrackedArray) =
+  TrackedArray(xs.f, adapt(T, xs.data),
+               RefValue(adapt(T, grad(xs))))
 
 end

src/tree.jl (new file, 25 lines)
@@ -0,0 +1,25 @@
+children(x) = ()
+mapchildren(f, x) = x
+
+function treelike(T, fs = fieldnames(T))
+  @eval begin
+    children(x::$T) = ($([:(x.$f) for f in fs]...),)
+    mapchildren(f, x::$T) = $T(f.(children(x))...)
+  end
+end
+
+# TODO: prewalk/postwalk with correct caching
+# This is only correct in general for idempotent functions
+
+mapparams(f, x::AbstractArray) = f(x)
+mapparams(f, x) = mapchildren(x -> mapparams(f, x), x)
+
+forparams(f, x) = (mapparams(x -> (f(x); x), x); return)
+
+using DataFlow: OSet
+
+function params(m)
+  ps = OSet()
+  forparams(p -> push!(ps, p), m)
+  return collect(ps)
+end
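`treelike` is what replaces the per-layer `Optimise.children` definitions elsewhere in this diff: given a type, it derives `children`/`mapchildren` from the fields, and `mapparams`/`params` then walk that tree. A hedged sketch with a made-up layer:

```julia
# Hypothetical custom layer (illustrative, not in the diff).
struct Affine
  W
  b
end

Flux.treelike(Affine)       # derives children/mapchildren from the W and b fields

a = Affine(rand(3, 2), rand(3))
params(a)                   # collects W and b
a2 = mapparams(x -> 2x, a)  # rebuilds an Affine with doubled parameters
```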
test/compiler.jl (deleted file, 86 lines)
@@ -1,86 +0,0 @@
-using DataFlow, MacroTools
-using Flux: stack, unsqueeze
-using Flux.Compiler: @net, graph
-using DataFlow: Line, Frame
-
-@net type Affine
-  W
-  b
-  x -> x*W .+ b
-end
-
-Affine(in::Integer, out::Integer; init = Flux.initn) =
-  Affine(init(in, out), init(1, out))
-
-@net type TLP
-  first
-  second
-  function (x)
-    l1 = σ.(first(x))
-    l2 = softmax(second(l1))
-  end
-end
-
-@net type Recurrent
-  Wxy; Wyy; by
-  y
-  function (x)
-    y = tanh.( x * Wxy .+ y{-1} * Wyy .+ by )
-  end
-end
-
-Recurrent(in, out; init = Flux.initn) =
-  Recurrent(init((in, out)), init((out, out)), init(1, out), init(1, out))
-
-syntax(v::Vertex) = prettify(DataFlow.syntax(v))
-syntax(x) = syntax(graph(x))
-
-@testset "Compiler" begin
-
-xs = randn(1, 10)
-d = Affine(10, 20)
-
-@test d(xs) ≈ (xs*d.W + d.b)
-
-d1 = @net x -> x * d.W + d.b
-
-let
-  @capture(syntax(d), _Frame(_Line((+).(x_[1] * W_, b_))))
-  @test isa(x, DataFlow.Input) && W isa Array && b isa Array
-end
-
-let a1 = Affine(10, 20), a2 = Affine(20, 15)
-  tlp = TLP(a1, a2)
-  @test tlp(xs) ≈ softmax(a2(σ.(a1(xs))))
-  @test Flux.Compiler.interpmodel(tlp, xs) ≈ softmax(a2(σ.(a1(xs))))
-end
-
-let tlp = TLP(Affine(10, 21), Affine(20, 15))
-  e = try
-    Flux.Compiler.interpmodel(tlp, rand(1, 10))
-  catch e
-    e
-  end
-  @test e.trace[end].func == :TLP
-  @test e.trace[end-1].func == Symbol("Affine")
-end
-
-function apply(model, xs, state)
-  ys = similar(xs, 0)
-  for x in xs
-    state, y = model(state, x)
-    push!(ys, y)
-  end
-  state, ys
-end
-
-@testset "RNN unrolling" begin
-  r = Recurrent(10, 5)
-  xs = [rand(1, 10) for _ = 1:3]
-  _, ys = apply(Flux.Compiler.unroll1(r).model, xs, (r.y,))
-  @test ys[1] == tanh.(xs[1] * r.Wxy .+ r.y * r.Wyy .+ r.by)
-  ru = Flux.Compiler.unroll(r, 3)
-  ru(unsqueeze(stack(squeeze.(xs, 1), 1), 1))[1] == squeeze.(ys, 1)
-end
-
-end
test/runtests.jl
@@ -2,7 +2,6 @@ using Flux, Base.Test
 
 @testset "Flux" begin
 
-include("compiler.jl")
 include("utils.jl")
 include("tracker.jl")
 