simplify back!
of Chain
This commit is contained in:
parent
cca21a617c
commit
c9ae219613
@ -9,21 +9,15 @@ end
|
|||||||
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
|
||||||
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
|
||||||
|
|
||||||
function back!(s::Chain, Δ, xs...)
|
function back!(s::Chain, Δ, x)
|
||||||
crumbs = Tuple[xs]
|
crumbs = foldl([x], s.layers[1:end-1]) do crumbs, layer
|
||||||
N = length(s.layers)
|
push!(crumbs, layer(crumbs[end]))
|
||||||
|
|
||||||
for i in 1:N-1
|
|
||||||
xs = s.layers[i](xs...)
|
|
||||||
xs isa Tuple || (xs = (xs, ))
|
|
||||||
push!(crumbs, xs)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
for i in N:-1:1
|
foldr(Δ, collect(zip(crumbs, s.layers))) do pack, Δ
|
||||||
Δ = back!(s.layers[i], Δ, crumbs[i]...)
|
x, layer = pack
|
||||||
|
back!(layer, Δ, x)
|
||||||
end
|
end
|
||||||
|
|
||||||
Δ
|
|
||||||
end
|
end
|
||||||
|
|
||||||
graph(s::Chain) =
|
graph(s::Chain) =
|
||||||
|
Loading…
Reference in New Issue
Block a user