checkpoints
parent 7778d17884
commit 1430053b69

@@ -87,6 +87,22 @@ the sign of the gradient applied to `x`.
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))

"""
    checkpoint(f, args...)

Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)

@grad function checkpoint(f, args...)
  data(f(args...)), function (Δ)
    y, back = forward(f, args...)
    (nothing, back(Δ)...)
  end
end

param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))

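A minimal usage sketch, not part of this diff: it assumes `checkpoint` and `derivative` are imported from `Flux.Tracker`, and the function `expensive` and the constants below are made up for illustration.

using Flux.Tracker: checkpoint, derivative

expensive(a, x) = a * sin(x)

# Same value and gradient as calling `expensive` directly, but its intermediates
# are not kept for the backward pass; the function is simply evaluated a second
# time when the gradient is pulled back.
derivative(x -> checkpoint(expensive, 2.0, x), 0.5)  # ≈ 2.0 * cos(0.5)
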
@@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv

gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)

@@ -241,4 +241,16 @@ end
@test grad(x) == -1
end

@testset "Checkpointing" begin
  count = 0
  function mul(a, b)
    count += 1
    a * b
  end
  @test derivative(x -> mul(5, x), 3) == 5
  @test count == 1
  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
  @test count == 3
end

end #testset
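
A hedged sketch of the "larger models" case the docstring mentions (illustrative only; the layer sizes, the `blocks` variable, and the training-style call are assumptions, not part of this commit):

using Flux
using Flux.Tracker: checkpoint

# Ten two-layer blocks; without checkpointing, every block's activations
# stay alive until the backward pass.
blocks = [Chain(Dense(100, 100, relu), Dense(100, 100, relu)) for _ in 1:10]

function model(x)
  h = x
  for blk in blocks
    h = checkpoint(blk, h)  # each block is re-run during back-propagation
  end
  return h
end

x = param(randn(100, 32))
Flux.Tracker.back!(sum(model(x)))  # gradients as usual, with less stored state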