Flux.jl/src/layers/basic.jl

"""
    Chain(layers...)

Chain multiple layers / functions together, so that they are called in sequence
on a given input.

```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26

m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```

`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
struct Chain{T<:Tuple}
  layers::T
  Chain(xs...) = new{typeof(xs)}(xs)
end

@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex

functor(c::Chain) = c.layers, ls -> Chain(ls...)
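# Apply the first layer and recurse on the rest; recursing over the layer tuple
# keeps the chain's forward pass type-stable for a fixed set of layer types.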
applychain(::Tuple{}, x) = x
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
(c::Chain)(x) = applychain(c.layers, x)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
function Base.show(io::IO, c::Chain)
print(io, "Chain(")
join(io, c.layers, ", ")
print(io, ")")
end

# This is a temporary and naive implementation
# it might be replaced in the future for better performance
# see issue https://github.com/FluxML/Flux.jl/issues/702
# Johnny Chen -- @johnnychen94
"""
    activations(c::Chain, input)

Calculate the forward results of each layer in Chain `c` with `input` as model input.
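
For example (an illustrative sketch; the exact values depend on the random initialisation):

```julia
m = Chain(Dense(10, 5, relu), Dense(5, 2), softmax)
x = rand(10)
acts = activations(m, x)
length(acts) == 3    # one entry per layer
acts[end] == m(x)    # the last entry is the full forward pass
```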
"""
function activations(c::Chain, input)
  rst = []
  for l in c
    # feed the previous layer's output, or `input` for the first layer
    x = get(rst, length(rst), input)
    push!(rst, l(x))
  end
  return rst
end
"""
    Dense(in::Integer, out::Integer, σ = identity)

Creates a traditional `Dense` layer with parameters `W` and `b`.

    y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.

```julia
julia> d = Dense(5, 2)
Dense(5, 2)

julia> d(rand(5))
Tracked 2-element Array{Float64,1}:
  0.00257447
 -0.00449443
```
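
The same layer applied to a batch of inputs, one column per sample (illustrative):

```julia
julia> size(d(rand(5, 64)))
(2, 64)
```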
"""
struct Dense{F,S,T}
  W::S
  b::T
  σ::F
end
Dense(W, b) = Dense(W, b, identity)
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros)
  return Dense(initW(out, in), initb(out), σ)
end
@functor Dense
function (a::Dense)(x::AbstractArray)
  W, b, σ = a.W, a.b, a.σ
  σ.(W*x .+ b)
end
function Base.show(io::IO, l::Dense)
print(io, "Dense(", size(l.W, 2), ", ", size(l.W, 1))
2017-08-21 16:20:09 +00:00
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
# Try to avoid hitting generic matmul in some simple cases
# Base's matmul is so slow that it's worth the extra conversion to hit BLAS
(a::Dense{<:Any,W})(x::AbstractArray{T}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  invoke(a, Tuple{AbstractArray}, x)

# Convert the input to the weights' element type so the BLAS path above is taken
(a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} =
  a(T.(x))
"""
    Diagonal(in::Integer)

Creates an element-wise linear transformation layer with learnable
vectors `α` and `β`:

    y = α .* x .+ β

The input `x` must be an array where `size(x, 1) == in`.
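
For example, with the default initialisation (`initα = ones`, `initβ = zeros`):

```julia
d = Diagonal(2)
d([1, 10]) == [1.0, 10.0]
```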
"""
struct Diagonal{T}
  α::T
  β::T
end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
  Diagonal(initα(in), initβ(in))
@functor Diagonal
function (a::Diagonal)(x)
  α, β = a.α, a.β
  α.*x .+ β
end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
2017-10-10 20:33:37 +00:00
end
"""
    Maxout(over)

`Maxout` is a neural network layer that has a number of internal layers,
which all receive the same input. The layer returns the elementwise maximum
of the internal layers' outputs.

Maxout over linear dense layers satisfies the universal approximation theorem.
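
For example (an illustrative sketch using plain functions as the internal layers):

```julia
mo = Maxout((x -> x, x -> 2x, x -> 0.5x))
mo([1, 2]) == [2.0, 4.0]    # elementwise max over the three branches
```
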
Reference:
Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron Courville, and Yoshua Bengio.
2013. Maxout networks.
In Proceedings of the 30th International Conference on International Conference on Machine Learning - Volume 28 (ICML'13),
Sanjoy Dasgupta and David McAllester (Eds.), Vol. 28. JMLR.org III-1319-III-1327.
https://arxiv.org/pdf/1302.4389.pdf
"""
struct Maxout{FS<:Tuple}
  over::FS
end

"""
    Maxout(f, n_alts)

Constructs a Maxout layer over `n_alts` instances of the layer given by `f`.
The function takes no argument and should return some callable layer.
Conventionally this is a linear dense layer.

For example, the following constructs a `Maxout` layer over 4 internal dense
linear layers, each identical in structure (784 inputs, 128 outputs):
```julia
insize = 784
outsize = 128
Maxout(()->Dense(insize, outsize), 4)
```
"""
function Maxout(f, n_alts)
  over = Tuple(f() for _ in 1:n_alts)
  return Maxout(over)
end
@functor Maxout
function (mo::Maxout)(input::AbstractArray)
  mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over)
end
"""
    SkipConnection(layers, connection)

Creates a skip connection, consisting of a layer or `Chain` of consecutive layers
and a shortcut connection. The `connection` function combines the output of the
layers with the original input to give the final output.

The simplest 'ResNet'-type connection is just `SkipConnection(layer, +)`,
and requires the output of the layers to be the same shape as the input.
Here is a more complicated example:
```julia
m = Conv((3,3), 4=>7, pad=(1,1))
x = ones(5,5,4,10);
size(m(x)) == (5, 5, 7, 10)

sm = SkipConnection(m, (mx, x) -> cat(mx, x, dims=3))
size(sm(x)) == (5, 5, 11, 10)
```
"""
struct SkipConnection
  layers
  connection  #user can pass arbitrary connections here, such as (a,b) -> a + b
end
@functor SkipConnection

function (skip::SkipConnection)(input)
  skip.connection(skip.layers(input), input)
end

function Base.show(io::IO, b::SkipConnection)
  print(io, "SkipConnection(", b.layers, ", ", b.connection, ")")
end