faster default gradient performance
This commit is contained in:
parent
95fb46018d
commit
b3331205d1
|
@ -5,7 +5,8 @@ using MacroTools: @q, @forward
|
|||
|
||||
import Base: ==
|
||||
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
|
||||
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
|
||||
param, back!
|
||||
|
||||
# Fallback method: any value without a more specific `tracker` method
# yields `nothing`.
function tracker(x)
    return nothing
end
|
||||
|
||||
|
@ -99,7 +100,8 @@ end
|
|||
|
||||
# Record a `nobacksies` call on `x` through `track`, carrying `f` along as the
# label that the matching `@grad` rules use in their error message.
function nobacksies(f, x)
    return track(nobacksies, f, x)
end
|
||||
# Tuple method: apply `nobacksies(f, ·)` to every element of `xs` and return
# the results as a tuple.
function nobacksies(f, xs::Tuple)
    return map(xs) do element
        nobacksies(f, element)
    end
end
|
||||
# Gradient rule for `nobacksies` with an unconstrained label `f`: the first
# element of the pair is `data(x)`; the second is a closure that ignores Δ and
# raises an error naming `f` whenever it is called.
# NOTE(review): assumes `@grad` expects a (value, pullback) pair — confirm
# against the macro definition, which is not visible here.
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||
# Gradient rule for `nobacksies` when the label `f` is a `Symbol`: the first
# element of the pair is `data(x)`; the second is a closure that ignores Δ and
# raises an error that interpolates the symbol into a fixed
# "Nested AD not defined for ..." message.
# NOTE(review): assumes `@grad` expects a (value, pullback) pair — confirm
# against the macro definition, which is not visible here.
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
|
||||
# Gradient rule for `nobacksies` when the label `f` is a `String`: the first
# element of the pair is `data(x)`; the second is a closure that ignores Δ and
# raises an error whose message is the string `f` itself, verbatim.
# NOTE(review): assumes `@grad` expects a (value, pullback) pair — confirm
# against the macro definition, which is not visible here.
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
|
||||
|
||||
# Wrap a scalar as a `TrackedReal`, converting it with `float` first.
function param(x::Number)
    value = float(x)
    return TrackedReal(value)
end
|
||||
# Wrap an array as a `TrackedArray`, converting every element with `float` first.
function param(xs::AbstractArray)
    values = float.(xs)
    return TrackedArray(values)
end
|
||||
|
|
|
@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
|
|||
return
|
||||
end
|
||||
|
||||
# Non-nesting gradient of `f` at `xs`.
#
# Each argument is wrapped with `param`, `f` is called on the wrapped values,
# the result is passed through `losscheck`, and `back!` is run on it. The
# per-argument gradients read with `grad` are returned wrapped in `nobacksies`,
# labelled with a message pointing at `gradient(...; nest = true)`.
function gradient_(f, xs...)
    tracked = param.(xs)
    loss = f(tracked...)
    losscheck(loss)
    back!(loss)
    grads = grad.(tracked)
    return nobacksies(
        "Use `gradient(...; nest = true)` for nested derivatives",
        grads,
    )
end
|
||||
|
||||
# Out-of-place gradients
|
||||
|
||||
struct Params
|
||||
|
@ -162,20 +171,11 @@ function losscheck(x)
|
|||
isnan(x) && error("Loss is NaN")
|
||||
end
|
||||
|
||||
function gradient(f, args...)
|
||||
# Gradient of `f` at `args` computed through `forward`: the forward value is
# passed through `losscheck`, then the returned pullback is called with a
# seed of `1` and its result is returned.
function gradient_nested(f, args...)
    value, pullback = forward(f, args...)
    losscheck(value)
    return pullback(1)
end
|
||||
|
||||
# Derivative of a single-argument `f` at `x`: the first entry of what
# `gradient(f, x)` returns.
function derivative(f, x)
    grads = gradient(f, x)
    return grads[1]
end
|
||||
|
||||
# Non-nesting versions
|
||||
|
||||
# Non-nesting gradient of `f` at `xs`.
#
# Each argument is wrapped with `param`, `f` is called on the wrapped values,
# the result is passed through `losscheck`, and `back!` is run on it. The
# per-argument gradients read with `grad` are returned.
function gradient_(f, xs...)
    tracked = param.(xs)
    loss = f(tracked...)
    losscheck(loss)
    back!(loss)
    return grad.(tracked)
end
|
||||
# Gradient of `f` at `xs`.
#
# With the default `nest = false` the work is delegated to `gradient_`;
# passing `nest = true` delegates to `gradient_nested` instead.
function gradient(f, xs...; nest = false)
    if nest
        return gradient_nested(f, xs...)
    else
        return gradient_(f, xs...)
    end
end
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
|
||||
using NNlib: conv, depthwiseconv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
|
@ -285,9 +285,9 @@ end
|
|||
count += 1
|
||||
a * b
|
||||
end
|
||||
@test derivative(x -> mul(5, x), 3) == 5
|
||||
@test gradient(x -> mul(5, x), 3)[1] == 5
|
||||
@test count == 1
|
||||
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
|
||||
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
|
||||
@test count == 3
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue