diff --git a/src/core.jl b/src/core.jl
index d3953849..66e33440 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -6,11 +6,11 @@ module FluxCore
 """
     back!(model, ΔY, X...) => ΔX
 
-Backpropagate the gradient `ΔY` through the model `m`, accumulating the
+Backpropagate the gradient `ΔY` through the model `model`, accumulating the
 gradients of any parameters. Returns the gradient of the input `X`. Gradients
 may be arrays or tuples of arrays (for multiple inputs/outputs).
 """
-back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
+back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
 
 """
     update!(model, η) => m
diff --git a/src/layers/affine.jl b/src/layers/affine.jl
index 9608efcc..a1df4562 100644
--- a/src/layers/affine.jl
+++ b/src/layers/affine.jl
@@ -9,3 +9,16 @@ Affine(in::Integer, out::Integer; init = initn) =
 
 inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
   Affine(in[1][2], out)
+
+function back!(m::Affine, Δ, x)
+  W, b = m.W, m.b
+  W.Δx[:] = x' * Δ
+  b.Δx[:] = sum(Δ, 1)
+  Δ * W.x'
+end
+
+function update!(m::Affine, η)
+  update!(m.W, η)
+  update!(m.b, η)
+  m
+end
\ No newline at end of file
diff --git a/src/layers/control.jl b/src/layers/control.jl
index 7851f902..a08cb3cb 100644
--- a/src/layers/control.jl
+++ b/src/layers/control.jl
@@ -7,9 +7,25 @@ end
 
 @forward Chain.layers Base.start, Base.next, Base.done
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
 update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
 
+function back!(s::Chain, Δ, xs...)
+  crumbs = Tuple[xs]
+  N = length(s.layers)
+
+  for i in 1:N-1
+    xs = s.layers[i](xs...)
+    xs isa Tuple || (xs = (xs, ))
+    push!(crumbs, xs)
+  end
+
+  for i in N:-1:1
+    Δ = back!(s.layers[i], Δ, crumbs[i]...)
+  end
+
+  Δ
+end
+
 graph(s::Chain) = foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
 
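
# A usage sketch, not taken from the patch itself: how the new `back!`/`update!`
# methods for `Affine` and `Chain` might be exercised, assuming the surrounding
# 2017-era Flux API is in scope (`Affine(in, out)`, `Chain(layers...)`,
# row-per-sample inputs, and parameters carrying `.x`/`.Δx` fields). The shapes
# and learning rate below are illustrative only.
m = Chain(Affine(10, 5), Affine(5, 2))

x = randn(1, 10)       # one sample with 10 features
y = m(x)               # forward pass through both layers, size (1, 2)

Δy = ones(size(y))     # gradient arriving from the loss
Δx = back!(m, Δy, x)   # Chain re-runs the forward pass to collect each layer's
                       # input, then backpropagates, filling W.Δx / b.Δx
update!(m, 0.1)        # apply the accumulated gradients with η = 0.1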