# Flux.jl/src/layers/basic.jl — basic layers: `Chain` and `Dense`.

"""
    Chain(layers...)

Chain multiple layers / functions together, so that they are called in sequence
on a given input.

```julia
m = Chain(x -> x^2, x -> x+1)
m(5) == 26

m = Chain(Dense(10, 5), Dense(5, 2))
x = rand(10)
m(x) == m[2](m[1](x))
```

`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
`m[1:3](x)` will calculate the output of the first three layers.
"""
mutable struct Chain
    # Stages may be arbitrary callables of different types, hence `Vector{Any}`.
    # (`mutable struct` is the non-deprecated spelling of the old `type` keyword,
    # so field mutability is unchanged.)
    layers::Vector{Any}
    # Build the `Vector{Any}` directly instead of creating a narrower vector that
    # `new` would then have to convert to the field type.
    Chain(xs...) = new(Any[xs...])
end
# Delegate indexing and size queries to the underlying `layers` vector so a `Chain`
# behaves like a sequence of layers (`m[2]`, `m[1:end-1]`, `length(m)`, ...).
# `length` is needed because iterables default to a known-length iterator size, so
# consumers like `collect(m)` call it; without it they fail with a MethodError.
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last, Base.endof, Base.push!
# Pre-0.7 iteration protocol, so `for layer in m` walks the layers in order.
@forward Chain.layers Base.start, Base.next, Base.done
# Tree interface for `Chain`: its children are its layers, and mapping a function
# over them rebuilds a `Chain` from the transformed layers.
children(chain::Chain) = chain.layers
mapchildren(f, chain::Chain) = Chain(map(f, chain.layers)...)
# Run the input through every layer in order, feeding each output into the next
# layer; an empty `Chain` returns its input unchanged.
function (chain::Chain)(x)
    out = x
    for layer in chain.layers
        out = layer(out)
    end
    return out
end
# Indexing a `Chain` with a collection of indices (e.g. a range) yields a new
# `Chain` made of the selected layers.
function Base.getindex(chain::Chain, inds::AbstractArray)
    picked = chain.layers[inds]
    return Chain(picked...)
end
# Display as `Chain(layer1, layer2, ...)`, printing each layer directly to `io`.
function Base.show(io::IO, c::Chain)
    print(io, "Chain(")
    for (k, layer) in enumerate(c.layers)
        k > 1 && print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end
"""
    Dense(in::Integer, out::Integer, σ = identity; init = initn)

Create a traditional `Dense` layer with parameters `W` and `b`.

    y = σ.(W * x .+ b)

The input `x` must be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix. The output `y` will be a vector or batch of length `out`.

The parameters are created as `param(init(out, in))` and `param(init(out))`.

```julia
julia> d = Dense(5, 2)
Dense(5, 2)

julia> d(rand(5))
Tracked 2-element Array{Float64,1}:
  0.00257447
  -0.00449443
```
"""
struct Dense{F,S,T}
    σ::F  # activation, applied elementwise to the affine output
    W::S  # weight
    b::T  # bias
end
# Convenience constructor: size the layer by input/output width. The weight is
# built before the bias, so `init` is called in the same order as before.
function Dense(in::Integer, out::Integer, σ = identity; init = initn)
    W = param(init(out, in))
    b = param(init(out))
    return Dense(σ, W, b)
end

# Register `Dense` with the tree utilities; presumably this provides
# `children`/`mapchildren` for its fields, like the hand-written `Chain` methods
# (TODO confirm against `treelike`).
treelike(Dense)
# Forward pass: affine map `W*x`, bias added with broadcasting, then the
# activation applied elementwise (the `.+` and `σ.` are fused into one pass).
function (layer::Dense)(x)
    σ = layer.σ
    z = layer.W * x
    return σ.(z .+ layer.b)
end
# Compact display such as `Dense(5, 2)`: input width (columns of `W`), output width
# (rows of `W`), plus the activation when it is not `identity`.
function Base.show(io::IO, l::Dense)
    nin, nout = size(l.W, 2), size(l.W, 1)
    print(io, "Dense(", nin, ", ", nout)
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    print(io, ")")
end