checkpoints
This commit is contained in:
parent
7778d17884
commit
1430053b69
@ -87,6 +87,22 @@ the sign of the gradient applied to `x`.
|
|||||||
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
hook(f, x) = istracked(x) ? track(hook, f, x) : x
|
||||||
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
@grad hook(f, x) = x, Δ -> (nothing, f(Δ))
|
||||||
|
|
||||||
|
"""
|
||||||
|
checkpoint(f, args...)
|
||||||
|
|
||||||
|
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
|
||||||
|
calculating gradients. Instead, `f(args...)` will be called again during the
|
||||||
|
backward pass. This can be used to save memory in larger models.
|
||||||
|
"""
|
||||||
|
checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||||
|
|
||||||
|
@grad function checkpoint(f, args...)
|
||||||
|
data(f(args...)), function (Δ)
|
||||||
|
y, back = forward(f, args...)
|
||||||
|
(nothing, back(Δ)...)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux.Tracker, Base.Test, NNlib
|
using Flux.Tracker, Base.Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck, grad
|
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||||
using NNlib: conv
|
using NNlib: conv
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
@ -241,4 +241,16 @@ end
|
|||||||
@test grad(x) == -1
|
@test grad(x) == -1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Checkpointing" begin
|
||||||
|
count = 0
|
||||||
|
function mul(a, b)
|
||||||
|
count += 1
|
||||||
|
a * b
|
||||||
|
end
|
||||||
|
@test derivative(x -> mul(5, x), 3) == 5
|
||||||
|
@test count == 1
|
||||||
|
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
|
||||||
|
@test count == 3
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user