From c9ae2196137082bbdb76dca00dbb001cd9422469 Mon Sep 17 00:00:00 2001 From: ylxdzsw Date: Wed, 14 Jun 2017 21:58:37 +0800 Subject: [PATCH] simplify `back!` of `Chain` --- src/layers/affine.jl | 2 +- src/layers/control.jl | 18 ++++++------------ 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/src/layers/affine.jl b/src/layers/affine.jl index a1df4562..ca79c004 100644 --- a/src/layers/affine.jl +++ b/src/layers/affine.jl @@ -21,4 +21,4 @@ function update!(m::Affine, η) update!(m.W, η) update!(m.b, η) m -end \ No newline at end of file +end diff --git a/src/layers/control.jl b/src/layers/control.jl index a08cb3cb..d0c5e61b 100644 --- a/src/layers/control.jl +++ b/src/layers/control.jl @@ -9,21 +9,15 @@ end (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers) -function back!(s::Chain, Δ, xs...) - crumbs = Tuple[xs] - N = length(s.layers) - - for i in 1:N-1 - xs = s.layers[i](xs...) - xs isa Tuple || (xs = (xs, )) - push!(crumbs, xs) +function back!(s::Chain, Δ, x) + crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer + push!(crumbs, layer(crumbs[end])) end - for i in N:-1:1 - Δ = back!(s.layers[i], Δ, crumbs[i]...) + foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ + x, layer = pack + back!(layer, Δ, x) end - - Δ end graph(s::Chain) =