simplify back! of Chain

This commit is contained in:
ylxdzsw 2017-06-14 21:58:37 +08:00
parent cca21a617c
commit c9ae219613
2 changed files with 7 additions and 13 deletions

View File

@ -21,4 +21,4 @@ function update!(m::Affine, η)
update!(m.W, η) update!(m.W, η)
update!(m.b, η) update!(m.b, η)
m m
end end

View File

@ -9,21 +9,15 @@ end
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers) update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
function back!(s::Chain, Δ, xs...) function back!(s::Chain, Δ, x)
crumbs = Tuple[xs] crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
N = length(s.layers) push!(crumbs, layer(crumbs[end]))
for i in 1:N-1
xs = s.layers[i](xs...)
xs isa Tuple || (xs = (xs, ))
push!(crumbs, xs)
end end
for i in N:-1:1 foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
Δ = back!(s.layers[i], Δ, crumbs[i]...) x, layer = pack
back!(layer, Δ, x)
end end
Δ
end end
graph(s::Chain) = graph(s::Chain) =