simplify back! of Chain

This commit is contained in:
ylxdzsw 2017-06-14 21:58:37 +08:00
parent cca21a617c
commit c9ae219613
2 changed files with 7 additions and 13 deletions

View File

@ -21,4 +21,4 @@ function update!(m::Affine, η)
update!(m.W, η)
update!(m.b, η)
m
end
end

View File

@ -9,21 +9,15 @@ end
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
function back!(s::Chain, Δ, xs...)
crumbs = Tuple[xs]
N = length(s.layers)
for i in 1:N-1
xs = s.layers[i](xs...)
xs isa Tuple || (xs = (xs, ))
push!(crumbs, xs)
function back!(s::Chain, Δ, x)
crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
push!(crumbs, layer(crumbs[end]))
end
for i in N:-1:1
Δ = back!(s.layers[i], Δ, crumbs[i]...)
foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
x, layer = pack
back!(layer, Δ, x)
end
Δ
end
graph(s::Chain) =