faster default gradient performance
This commit is contained in:
parent
95fb46018d
commit
b3331205d1
@ -5,7 +5,8 @@ using MacroTools: @q, @forward
|
|||||||
|
|
||||||
import Base: ==
|
import Base: ==
|
||||||
|
|
||||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
|
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
||||||
|
param, back!
|
||||||
|
|
||||||
tracker(x) = nothing
|
tracker(x) = nothing
|
||||||
|
|
||||||
@ -99,7 +100,8 @@ end
|
|||||||
|
|
||||||
nobacksies(f, x) = track(nobacksies, f, x)
|
nobacksies(f, x) = track(nobacksies, f, x)
|
||||||
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
|
||||||
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||||
|
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
|
||||||
|
|
||||||
param(x::Number) = TrackedReal(float(x))
|
param(x::Number) = TrackedReal(float(x))
|
||||||
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
param(xs::AbstractArray) = TrackedArray(float.(xs))
|
||||||
|
@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
|
|||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function gradient_(f, xs...)
|
||||||
|
xs = param.(xs)
|
||||||
|
l = f(xs...)
|
||||||
|
losscheck(l)
|
||||||
|
back!(l)
|
||||||
|
nobacksies("Use `gradient(...; nest = true)` for nested derivatives",
|
||||||
|
grad.(xs))
|
||||||
|
end
|
||||||
|
|
||||||
# Out-of-place gradients
|
# Out-of-place gradients
|
||||||
|
|
||||||
struct Params
|
struct Params
|
||||||
@ -162,20 +171,11 @@ function losscheck(x)
|
|||||||
isnan(x) && error("Loss is NaN")
|
isnan(x) && error("Loss is NaN")
|
||||||
end
|
end
|
||||||
|
|
||||||
function gradient(f, args...)
|
function gradient_nested(f, args...)
|
||||||
y, back = forward(f, args...)
|
y, back = forward(f, args...)
|
||||||
losscheck(y)
|
losscheck(y)
|
||||||
return back(1)
|
return back(1)
|
||||||
end
|
end
|
||||||
|
|
||||||
derivative(f, x) = gradient(f, x)[1]
|
gradient(f, xs...; nest = false) =
|
||||||
|
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||||
# Non-nesting versions
|
|
||||||
|
|
||||||
function gradient_(f, xs...)
|
|
||||||
xs = param.(xs)
|
|
||||||
l = f(xs...)
|
|
||||||
losscheck(l)
|
|
||||||
back!(l)
|
|
||||||
grad.(xs)
|
|
||||||
end
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux
|
using Flux
|
||||||
using Flux.Tracker, Test, NNlib
|
using Flux.Tracker, Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
|
||||||
using NNlib: conv, depthwiseconv
|
using NNlib: conv, depthwiseconv
|
||||||
using Printf: @sprintf
|
using Printf: @sprintf
|
||||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||||
@ -285,9 +285,9 @@ end
|
|||||||
count += 1
|
count += 1
|
||||||
a * b
|
a * b
|
||||||
end
|
end
|
||||||
@test derivative(x -> mul(5, x), 3) == 5
|
@test gradient(x -> mul(5, x), 3)[1] == 5
|
||||||
@test count == 1
|
@test count == 1
|
||||||
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
|
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
|
||||||
@test count == 3
|
@test count == 3
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user