checkpoints

Mike J Innes 2018-07-09 17:52:34 +01:00
parent 7778d17884
commit 1430053b69
2 changed files with 29 additions and 1 deletion


@@ -87,6 +87,22 @@ the sign of the gradient applied to `x`.
hook(f, x) = istracked(x) ? track(hook, f, x) : x
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
"""
    checkpoint(f, args...)

Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)
@grad function checkpoint(f, args...)
  data(f(args...)), function (Δ)
    y, back = forward(f, args...)
    (nothing, back(Δ)...)
  end
end
param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs))
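
A quick usage sketch of the new `checkpoint` call (the `Chain`/`Dense` model, sizes, and input below are illustrative assumptions, not part of this commit):

using Flux
using Flux.Tracker: checkpoint, back!, grad

# A block whose intermediate activations we would rather recompute than store.
m = Chain(Dense(100, 100, relu), Dense(100, 100, relu))
x = param(rand(100))

# Behaves like m(x), but the backward pass re-runs m instead of keeping
# its intermediates around, trading extra compute for lower memory use.
y = checkpoint(m, x)
back!(sum(y))
grad(x)  # same gradient as back!(sum(m(x))) would produce for x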


@@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib
-using Flux.Tracker: TrackedReal, gradcheck, grad
+using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@@ -241,4 +241,16 @@ end
@test grad(x) == -1
end
@testset "Checkpointing" begin
  count = 0
  function mul(a, b)
    count += 1
    a * b
  end
  @test derivative(x -> mul(5, x), 3) == 5
  @test count == 1
  @test derivative(x -> checkpoint(mul, 5, x), 3) == 5
  @test count == 3
end
end #testset
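
A note on the counts in the test above: the plain `derivative` call runs `mul` once, so `count == 1`. The checkpointed call runs `mul` once while tracking the forward value and once more when `forward(f, args...)` is re-evaluated inside the pullback, so `count` ends at 3. That extra forward evaluation is the cost paid for not storing the intermediate values.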