dense -> affine

parent 6140448f17
commit bdd05157e2
@@ -19,9 +19,9 @@ We can describe simple models through a convenient interface:
 ```julia
 m = Chain(
   Input(784),
-  Dense(128), relu,
-  Dense( 64), relu,
-  Dense( 10), softmax)
+  Affine(128), relu,
+  Affine( 64), relu,
+  Affine( 10), softmax)
 ```

 Models are simple functions with state, so we can immediately see what the network does:
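Note that the single-argument form survives the rename: `Affine(128)` still leaves the input size to be inferred from `Input(784)`, via the `Init` wrapper updated in the `shape.jl` hunk at the end of this diff. A sketch of the equivalence under that reading:

```julia
# Equivalent spellings once shapes propagate (sketch; relies on
# Affine(out::Integer) = Init(in -> Affine(in, out)) from the shape.jl hunk):
Chain(Input(784), Affine(128), relu)       # Affine's input size inferred as 784
Chain(Input(784), Affine(784, 128), relu)  # the same model, fully explicit
```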
@@ -30,7 +30,7 @@ Models are simple functions with state, so we can immediately see what the network does:
 m(randn(784)) #> [0.101, 0.101, 0.099, 0.100, ...]
 ```

-What if we need a custom layer? Here's one equivalent to `Dense` above:
+What if we need a custom layer? Here's one equivalent to `Affine` above:

 ```julia
 # Simple Julia type with two fields – @net defines some extra methods like the
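The README snippet is cut off here, but the bundled definition it mirrors appears in full later in this diff, in the renamed `layers/Affine.jl`:

```julia
# The layer the README example is equivalent to, as this diff defines it:
@net type Affine
  W
  b
  x -> x*W + b
end
```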
@@ -55,10 +55,10 @@ We can already insert this model into combining models like `Chain`. If you want
   x -> σ(layer(x))
 end

-Perceptron(in, out) = Perceptron(Dense(in, out))
+Perceptron(in, out) = Perceptron(Affine(in, out))
 ```

-This defines a simple perceptron layer which we can use in the same way as `Dense` above. We can draw arbitrary graphs, including those with splits, combines or recurrences, in a fully declarative way *[this API is a WIP]*:
+This defines a simple perceptron layer which we can use in the same way as `Affine` above. We can draw arbitrary graphs, including those with splits, combines or recurrences, in a fully declarative way *[this API is a WIP]*:

 ```julia
 @net type SimpleRecurrent
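A hypothetical usage sketch of the wrapper above (names follow the README snippet; `σ` is assumed to broadcast over the layer's output):

```julia
# Illustrative only: Perceptron composes like any other model.
p = Perceptron(784, 10)   # wraps Affine(784, 10)
y = p(randn(1, 784))      # computes σ(layer(x)), a 1×10 row
```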
@@ -82,7 +82,7 @@ end
 end
 ```

-Though further from the equations, this has the advantage of further reuse and customizability. For example, `layer` could be a simple `Dense(x, y)` as before or it could be a `Dropout(Dense(x, y))` in order to add dropout to the recurrent layer.
+Though further from the equations, this has the advantage of further reuse and customizability. For example, `layer` could be a simple `Affine(x, y)` as before or it could be a `Dropout(Affine(x, y))` in order to add dropout to the recurrent layer.

 When it comes time to train the model, we have a number of options for tweaking its implementation, like the backend used or unrolling settings. In Flux this is as simple as calling some functions on the original model:

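The substitution that paragraph describes, sketched under the assumption that `SimpleRecurrent` keeps its inner model in a `layer` field (the README implies this, but the field is not shown in this extract); `Dropout` is named in the README rather than defined in this diff:

```julia
# Illustrative sizes; same recurrence, two choices of inner layer.
plain   = SimpleRecurrent(Affine(256, 256))
dropped = SimpleRecurrent(Dropout(Affine(256, 256)))
```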
@@ -6,9 +6,9 @@ test = data[50_001:60_000]

 m = Chain(
   Input(784),
-  Dense(128), relu,
-  Dense( 64), relu,
-  Dense( 10), softmax)
+  Affine(128), relu,
+  Affine( 64), relu,
+  Affine( 10), softmax)

 # Convert to TensorFlow
 model = tf(m)
@@ -14,7 +14,7 @@ model = Chain(
   Input(N),
   LSTM(N, 256),
   LSTM(256, 256),
-  Dense(256, N),
+  Affine(256, N),
   softmax)

 m = tf(unroll(model, 50));
@@ -28,8 +28,8 @@ conv2 = Chain(

 lenet = Chain(
   conv1, conv2, flatten,
-  Dense(500), tanh,
-  Dense(10), softmax)
+  Affine(500), tanh,
+  Affine(10), softmax)

 #--------------------------------------------------------------------------------

@@ -17,7 +17,7 @@ include("compiler/diff.jl")
 include("compiler/code.jl")
 include("compiler/loops.jl")

-include("layers/dense.jl")
+include("layers/Affine.jl")
 include("layers/recurrent.jl")
 include("layers/shape.jl")
 include("layers/chain.jl")
@@ -54,8 +54,8 @@ end

 hiddeninput(n) = vertex(Split(n), inputnode(1))

-function create_steps(v::IVertex, n)
-  [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n]
+function create_steps(v::IVertex, n; seq = true)
+  [bumpinputs(seq ? spliceinputs(v, hiddeninput(i)) : v) for i = 1:n]
 end

 function getvar(n, step, steps, offset, default)
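A reading of the new keyword, based on this hunk alone: with `seq = true` (the default) each unrolled step is spliced onto its own per-step input via `hiddeninput(i)`; with `seq = false` the vertex is copied unchanged, so every step shares the one original input.

```julia
# Sketch (v is the model's IVertex, n the unroll length):
steps = create_steps(v, n)               # n steps, one fresh input per step
steps = create_steps(v, n, seq = false)  # n copies sharing the original input
```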
@@ -78,10 +78,10 @@ function stateout(steps, offset, default)
   group(outs...), defaults
 end

-function unrollgraph(v::IVertex, n)
+function unrollgraph(v::IVertex, n; seq = true)
   state, offset, default = collect_state(v)
   v = group(group(state...), v)
-  steps = create_steps(v, n)
+  steps = create_steps(v, n, seq = seq)
   for i = 1:n
     vars = inputs(steps[i][1])
     postwalk!(steps[i]) do v
@@ -94,7 +94,7 @@ function unrollgraph(v::IVertex, n)
   group(state,group(map(x->x[2], steps)...)), map(Flux.state, defaults)
 end

-unrollgraph(m, n) = unrollgraph(atomise(m), n)
+unrollgraph(m, n; seq = true) = unrollgraph(atomise(m), n; seq = seq)

 type Unrolled <: Model
   model
@@ -105,6 +105,6 @@ end

 graph(u::Unrolled) = u.graph

-unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n)
+unroll(model, n; seq = true) = Unrolled(model, unrollgraph(model, n; seq = seq)..., n)

 flip(model) = Capacitor(map(x -> isa(x, Offset) ? -x : x, atomise(model)))
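The keyword threads through to the public entry point, so callers can pick the unrolling mode. A hedged usage sketch, reusing the char-RNN example from earlier in this diff:

```julia
m = tf(unroll(model, 50))               # sequence in, sequence out (default)
m = tf(unroll(model, 50, seq = false))  # one shared input across all 50 steps
```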
@@ -1,15 +1,15 @@
-export Dense
+export Affine

 # TODO: type hints for parameters

-@net type Dense
+@net type Affine
   W
   b
   x -> x*W + b
 end

-Dense(in::Integer, out::Integer; init = initn) =
-  Dense(init(in, out), init(1, out))
+Affine(in::Integer, out::Integer; init = initn) =
+  Affine(init(in, out), init(1, out))

 @net type Sigmoid
   layer::Model
@@ -17,4 +17,4 @@ Dense(in::Integer, out::Integer; init = initn) =
 end

 Sigmoid(in::Integer, out::Integer; init = randn) =
-  Sigmoid(Dense(in, out, init = init))
+  Sigmoid(Affine(in, out, init = init))
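A minimal sketch of the renamed layer in isolation, assuming `initn` is the small-random-weight initialiser the constructor defaults to:

```julia
a = Affine(10, 2)    # W from initn(10, 2), b from initn(1, 2)
a(randn(1, 10))      # row vector in; x*W + b gives a 1×2 row out
```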
@@ -42,6 +42,6 @@ shape(i::Input, _) = i.dims

 # Implementation for bundled layers

-shape(d::Dense, _) = length(state(d.b)) # TODO: could perhaps infer this
+shape(d::Affine, _) = length(state(d.b)) # TODO: could perhaps infer this

-Dense(out::Integer) = Init(in::Integer -> Dense(in, out))
+Affine(out::Integer) = Init(in::Integer -> Affine(in, out))
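These two lines are what make the README's single-argument form work: a built layer reports its output size from its bias, and `Affine(out)` defers construction until the input size arrives. Using the names from this hunk:

```julia
a = Affine(784, 128)
shape(a, nothing)    # == length(state(a.b)) == 128
Affine(128)          # == Init(in -> Affine(in, 128)), resolved inside Chain
```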