commit
4cba46c293
@ -2,6 +2,7 @@ module Flux
|
|||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
|
using Base: tail
|
||||||
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
using MacroTools, Juno, Requires, Reexport, Statistics, Random
|
||||||
using MacroTools: @forward
|
using MacroTools: @forward
|
||||||
|
|
||||||
|
@ -16,18 +16,21 @@ m(x) == m[2](m[1](x))
|
|||||||
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
`Chain` also supports indexing and slicing, e.g. `m[2]` or `m[1:end-1]`.
|
||||||
`m[1:3](x)` will calculate the output of the first three layers.
|
`m[1:3](x)` will calculate the output of the first three layers.
|
||||||
"""
|
"""
|
||||||
struct Chain
|
struct Chain{T<:Tuple}
|
||||||
layers::Vector{Any}
|
layers::T
|
||||||
Chain(xs...) = new([xs...])
|
Chain(xs...) = new{typeof(xs)}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex, Base.push!
|
@forward Chain.layers Base.getindex, Base.first, Base.last, Base.lastindex
|
||||||
@forward Chain.layers Base.iterate
|
@forward Chain.layers Base.iterate
|
||||||
|
|
||||||
children(c::Chain) = c.layers
|
children(c::Chain) = c.layers
|
||||||
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
|
||||||
|
|
||||||
(c::Chain)(x) = foldl((x, m) -> m(x), c.layers; init = x)
|
applychain(::Tuple{}, x) = x
|
||||||
|
applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||||
|
|
||||||
|
(c::Chain)(x) = applychain(c.layers, x)
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user