# Flux.jl/src/layers/chain.jl
#
# 37 lines, 769 B, Julia
#
# 2016-08-25 21:49:21 +00:00
export Chain
"""
    inferchain(ms)

Thread a shape through the models `ms`, in order. Each model is passed to
`init` together with `single` of the current shape. `shape(model, current)`
then yields the shape handed to the next model. The shape starts as `nothing`.

Return `(layers, shape)`:
- `layers` is a `Vector{Any}` of the models returned by `init`.
- `shape` is the shape after the last model (`nothing` when `ms` is empty).
"""
function inferchain(ms)
  layers = Any[]
  current = nothing
  for model in ms
    ready = init(model, single(current))
    current = shape(ready, current)
    push!(layers, ready)
  end
  return layers, current
end
"""
    Chain(ms...)

A model made of the models `ms`. At construction the models are passed
through `inferchain`. The models it returns are stored in `layers`, and the
shape it returns is stored in `shape`.
"""
type Chain <: Model
  layers::Vector{Any}  # models as returned by `inferchain`, in order
  shape                # final shape reported by `inferchain`
  function Chain(ms...)
    # Destructure rather than splat into `new`: older Julia cannot splat there.
    built, outshape = inferchain(ms)
    return new(built, outshape)
  end
end
# Forward `getindex`, `first` and `last` on a Chain to its `layers` vector.
# NOTE(review): `@forward` is defined outside this file; this assumes it has
# the usual Lazy.jl-style "delegate to field" semantics. Confirm.
@forward Chain.layers Base.getindex, Base.first, Base.last
# Run the chain: apply each layer in order, feeding each output into the
# next layer. An empty chain returns `x` unchanged.
function (s::Chain)(x)
  out = x
  for layer in s.layers
    out = layer(out)
  end
  return out
end
# Backward pass through the chain. `foldr` visits the layers last-to-first.
# Each layer's `back!` receives the value returned by the `back!` of the layer
# after it; the last layer receives `∇` itself. The result of the first
# layer's `back!` is returned. (The gradient identifier `∇` had been lost to
# an encoding error, leaving this line a syntax error.)
back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers)
# Call `update!(layer, η)` on every layer of the chain, in order.
function update!(s::Chain, η)
  for layer in s.layers
    update!(layer, η)
  end
  return nothing
end
# 2016-08-31 05:29:52 +00:00
# Print a chain as `Chain(layer1, layer2, ...)`. Each layer is written with
# `print`, with ", " between consecutive layers. This is the same output
# `print_joined` produced, without relying on that function (deprecated in
# Julia 0.6, removed in 0.7).
function Base.show(io::IO, c::Chain)
  print(io, "Chain(")
  for (i, layer) in enumerate(c.layers)
    i > 1 && print(io, ", ")
    print(io, layer)
  end
  print(io, ")")
end
# 2016-08-25 21:49:21 +00:00
# Build the graph for a chain. Start from `constant(ModelInput(1))`, then
# wrap the running value with `vertex(layer, ·)` once per layer, in order.
# With no layers, the starting constant is returned.
function graph(s::Chain)
  node = constant(ModelInput(1))
  for layer in s.layers
    node = vertex(layer, node)
  end
  return node
end