Flux.jl/src/layers/basic.jl
2018-06-26 14:30:46 +01:00

120 lines
2.6 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
    Chain(layers...)

Chain multiple layers / functions together, so that they are called in sequence
on a given input.

```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26
m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
mutable struct Chain
    # `mutable struct` is the current spelling of the deprecated `type` keyword;
    # the semantics (a mutable composite type) are unchanged.
    #
    # The element type is `Any`, so heterogeneous layers / functions can be
    # stored side by side in one vector.
    layers::Vector{Any}
    # Sole constructor: collects the positional arguments, in order, into the
    # layer vector. `Chain()` produces a chain with no layers.
    Chain(xs...) = new([xs...])
end
# Delegate collection behaviour of a `Chain` to its underlying `layers` vector.
# Element access:
@forward Chain.layers Base.getindex, Base.first, Base.last
# Last-index lookup (for `m[end]`) and appending a layer in place:
@forward Chain.layers Base.endof, Base.push!
# Iteration protocol functions:
@forward Chain.layers Base.start, Base.next, Base.done
# The children of a `Chain` are its layers; the stored vector itself is
# returned, not a copy.
function children(c::Chain)
    return c.layers
end
# Build a new `Chain` whose layers are `f` applied to each layer of `c`,
# in the original order. `c` itself is not modified.
function mapchildren(f, c::Chain)
    mapped = map(f, c.layers)
    return Chain(mapped...)
end
# Apply `adapt(T, layer)` to every layer and rebuild a `Chain` from the results.
function adapt(T, c::Chain)
    converted = [adapt(T, layer) for layer in c.layers]
    return Chain(converted...)
end
# Calling a `Chain` feeds `x` through each layer in order: the output of one
# layer is the input of the next. A chain with no layers returns `x` unchanged.
function (c::Chain)(x)
    out = x
    for layer in c.layers
        out = layer(out)
    end
    return out
end
# Indexing with an array of indices (e.g. a range such as `1:3`) yields a new
# `Chain` made of the selected layers rather than a bare vector.
function Base.getindex(c::Chain, i::AbstractArray)
    selected = c.layers[i]
    return Chain(selected...)
end
# Print a chain as `Chain(layer1, layer2, ...)`, each layer written with `print`.
function Base.show(io::IO, c::Chain)
    print(io, "Chain(")
    for (k, layer) in enumerate(c.layers)
        # Separator goes between layers only, never before the first one.
        k == 1 || print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end
# Workaround for the `accumulate` call in `activations` below: adds a method to
# the Base-internal `rcum_promote_type` that answers `Any` whenever the last
# type argument is `Any`.
# Seem to need this for `accumulate`; try removing on 0.7
# NOTE(review): presumably triggered because `Chain.layers` is a `Vector{Any}` —
# confirm. This extends a function owned by Base using only Base types (type
# piracy), so it is visible to every caller in the session.
Base.rcum_promote_type(op, ::Type, ::Type{Any}) = Any
# Run `x` through the chain and collect every intermediate result: entry `i`
# of the returned array is the output after the first `i` layers.
function activations(c::Chain, x)
    # Accumulation step: feed the running value into the next layer.
    feed(acc, layer) = layer(acc)
    return accumulate(feed, x, c.layers)
end
"""
    Dense(in::Integer, out::Integer, σ = identity)

Creates a traditional `Dense` layer with parameters `W` and `b`.

    y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.

```julia
julia> d = Dense(5, 2)
Dense(5, 2)
julia> d(rand(5))
Tracked 2-element Array{Float64,1}:
0.00257447
-0.00449443
```
"""
struct Dense{F, S, T}
    # Weight array. `show` reports `size(W, 2)` as the input size and
    # `size(W, 1)` as the output size.
    W::S
    # Bias, added elementwise to `W * x`.
    b::T
    # Activation function, broadcast over `W * x .+ b`.
    σ::F
end
# Convenience constructor from existing weights and bias; the activation
# defaults to `identity`.
function Dense(W, b)
    return Dense(W, b, identity)
end
# Construct a `Dense` layer mapping `in` inputs to `out` outputs.
#
# - `σ`: activation function (defaults to `identity`).
# - `initW`: called as `initW(out, in)` to create the initial weights.
# - `initb`: called as `initb(out)` to create the initial bias.
#
# Both initial arrays are wrapped with `param` before being stored.
function Dense(in::Integer, out::Integer, σ = identity;
               initW = glorot_uniform, initb = zeros)
    # Weights are created before the bias, matching the original call order.
    weight = param(initW(out, in))
    bias = param(initb(out))
    return Dense(weight, bias, σ)
end
# Register `Dense` with `treelike` (defined elsewhere in the package).
# NOTE(review): presumably this is what exposes the `W`, `b` and `σ` fields to
# `children` / `mapchildren` — confirm against the `treelike` definition.
treelike(Dense)
# Forward pass of a `Dense` layer: computes `σ.(W * x .+ b)`.
function (a::Dense)(x)
    W = a.W
    b = a.b
    σ = a.σ
    # `@fix` is a macro defined elsewhere in the package; the whole dotted
    # expression is handed to it unchanged.
    return @fix σ.(W*x .+ b)
end
# Print a layer as `Dense(in, out)` or, when the activation is not `identity`,
# as `Dense(in, out, σ)`. The sizes are read from the weight array.
function Base.show(io::IO, l::Dense)
    nin = size(l.W, 2)
    nout = size(l.W, 1)
    print(io, "Dense(", nin, ", ", nout)
    if !(l.σ == identity)
        print(io, ", ", l.σ)
    end
    print(io, ")")
end
"""
    Diagonal(in::Integer)

Creates an element-wise linear transformation layer with learnable
vectors `α` and `β`:

    y = α .* x .+ β

The input `x` must be an array where `size(x, 1) == in`.
"""
struct Diagonal{T}
    # Elementwise scale, multiplied with the input. Shares the type `T` with `β`.
    α::T
    # Elementwise shift, added after scaling.
    β::T
end
# Construct a `Diagonal` layer for inputs whose first dimension has size `in`.
#
# - `initα`: called as `initα(in)` to create the initial scale (default `ones`).
# - `initβ`: called as `initβ(in)` to create the initial shift (default `zeros`).
#
# Both arrays are wrapped with `param` before being stored.
function Diagonal(in::Integer; initα = ones, initβ = zeros)
    scale = param(initα(in))
    shift = param(initβ(in))
    return Diagonal(scale, shift)
end
# Register `Diagonal` with `treelike` (defined elsewhere in the package).
# NOTE(review): presumably this is what exposes the `α` and `β` fields to
# `children` / `mapchildren` — confirm against the `treelike` definition.
treelike(Diagonal)
# Forward pass of a `Diagonal` layer: elementwise scale by `α`, then shift by `β`.
function (a::Diagonal)(x)
    return a.α .* x .+ a.β
end
# Print a layer as `Diagonal(n)`, where `n` is the length of the scale vector `α`.
function Base.show(io::IO, l::Diagonal)
    n = length(l.α)
    print(io, "Diagonal(", n, ")")
end