training Julia models
parent 358ba650ad
commit cca21a617c
@@ -6,11 +6,11 @@ module FluxCore
 """
     back!(model, ΔY, X...) => ΔX
 
-Backpropagate the gradient `ΔY` through the model `m`, accumulating the
+Backpropagate the gradient `ΔY` through the model `model`, accumulating the
 gradients of any parameters. Returns the gradient of the input `X`. Gradients
 may be arrays or tuples of arrays (for multiple inputs/outputs).
 """
-back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(m))")
+back!(model, Δ, xs...) = error("Backprop not implemented for $(typeof(model))")
 
 """
     update!(model, η) => m
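The docstring above specifies the interface each layer is expected to implement: back! accumulates parameter gradients and returns the gradient of the input, update! applies what has been accumulated. As a rough, self-contained illustration of that contract (not part of this commit; the Scale layer and its fields are invented for this sketch):

# Hypothetical layer illustrating the back!/update! contract above;
# `Scale` and its fields are invented for this sketch, not Flux code.
mutable struct Scale
  w::Matrix{Float64}   # parameter
  Δw::Matrix{Float64}  # accumulated gradient
end

Scale(n::Integer) = Scale(randn(n, n), zeros(n, n))

(m::Scale)(x) = x * m.w

# back!(model, ΔY, X...) => ΔX: accumulate gradients, return ∂loss/∂x.
function back!(m::Scale, Δ, x)
  m.Δw .+= x' * Δ
  Δ * m.w'
end

# update!(model, η) => model: take a descent step and clear the gradient.
function update!(m::Scale, η)
  m.w .-= η .* m.Δw
  fill!(m.Δw, 0)
  m
end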
@@ -9,3 +9,16 @@ Affine(in::Integer, out::Integer; init = initn) =
 
 inferred(::Type{Affine}, in::Tuple{Dims{2}}, out::Integer) =
   Affine(in[1][2], out)
+
+function back!(m::Affine, Δ, x)
+  W, b = m.W, m.b
+  W.Δx[:] = x' * Δ
+  b.Δx[:] = sum(Δ, 1)
+  Δ * W.x'
+end
+
+function update!(m::Affine, η)
+  update!(m.W, η)
+  update!(m.b, η)
+  m
+end
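For an affine layer y = x*W .+ b with the batch along the rows of x, the chain rule gives ∂loss/∂W = x'Δ, ∂loss/∂b = the column-wise sum of Δ, and ∂loss/∂x = Δ*W', which is exactly what the three lines of back! above compute (W.x and W.Δx are presumably the value and gradient slots of Flux's parameter wrapper). A standalone sanity check of that algebra with plain arrays, written against current Julia (`sum(Δ, dims=1)` is the newer spelling of the `sum(Δ, 1)` used in the diff):

# Plain-array check of the affine backprop algebra above (a sketch,
# independent of Flux's parameter wrapper type).
x = randn(4, 3)           # batch of 4 inputs with 3 features
W = randn(3, 2); b = randn(1, 2)
Δ = randn(4, 2)           # incoming gradient ∂loss/∂y for y = x*W .+ b

ΔW = x' * Δ               # matches W.Δx[:] = x' * Δ
Δb = sum(Δ, dims=1)       # matches b.Δx[:] = sum(Δ, 1)
Δx = Δ * W'               # matches the returned Δ * W.x'

# Finite-difference check of one weight, with loss(y) = sum(Δ .* y).
ϵ  = 1e-6
Wp = copy(W); Wp[1, 1] += ϵ
fd = (sum(Δ .* (x * Wp .+ b)) - sum(Δ .* (x * W .+ b))) / ϵ
@assert isapprox(fd, ΔW[1, 1]; atol = 1e-3)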
@@ -7,9 +7,25 @@ end
 @forward Chain.layers Base.start, Base.next, Base.done
 
 (s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
-back!(s::Chain, Δ) = foldr((m, Δ) -> back!(m, Δ), Δ, s.layers)
 update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers)
+
+function back!(s::Chain, Δ, xs...)
+  crumbs = Tuple[xs]
+  N = length(s.layers)
+
+  for i in 1:N-1
+    xs = s.layers[i](xs...)
+    xs isa Tuple || (xs = (xs, ))
+    push!(crumbs, xs)
+  end
+
+  for i in N:-1:1
+    Δ = back!(s.layers[i], Δ, crumbs[i]...)
+  end
+
+  Δ
+end
 
 graph(s::Chain) =
   foldl((v, m) -> vertex(m, v), constant(inputnode(1)), s.layers)
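The new Chain back! re-runs the forward pass, caching each layer's input as a "crumb", and then threads Δ back through the layers in reverse, so every layer's back! receives both the incoming gradient and the input it originally saw. A hedged sketch of a training step using this interface, assuming only the Affine and Chain definitions touched by this commit:

# Usage sketch (not part of this commit); assumes the Affine and Chain
# methods defined above, plus the Affine(in, out) constructor.
m = Chain(Affine(10, 20), Affine(20, 5))

x = rand(1, 10)       # one row-vector input
y = m(x)              # forward pass, folding x through every layer

Δ = y .- rand(1, 5)   # e.g. the gradient of a squared-error loss at the output
back!(m, Δ, x)        # walk the layers in reverse, accumulating parameter gradients
update!(m, 0.1)       # apply the accumulated gradients with η = 0.1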