faster default gradient performance

This commit is contained in:
Mike J Innes 2018-11-12 23:39:25 +00:00
parent 95fb46018d
commit b3331205d1
3 changed files with 19 additions and 17 deletions

View File

@ -5,7 +5,8 @@ using MacroTools: @q, @forward
import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
export TrackedArray, TrackedVector, TrackedMatrix, Params, gradient,
param, back!
# Fallback method: any argument without a more specific `tracker` method
# yields `nothing`.
function tracker(x)
    return nothing
end
@ -99,7 +100,8 @@ end
# Pass `x` through `track`, registering `nobacksies` itself as the tracked call
# and carrying the label `f` along as its first argument.
nobacksies(f, x) = track(nobacksies, f, x)
# Tuple method: apply `nobacksies` with the same label `f` to each element.
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
# Gradient rules for `nobacksies`: each returns `data(x)` as the value together
# with a closure over Δ that unconditionally raises an error.
# NOTE(review): the untyped-`f` rule and the `f::Symbol` rule carry the same
# message; they look like the before/after lines of one change — confirm
# whether both are meant to be present.
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
# `f::Symbol`: the error message interpolates `f`.
@grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f")
# `f::String`: the string itself is used verbatim as the error message.
@grad nobacksies(f::String, x) = data(x), Δ -> error(f)
# Wrap a number in a `TrackedReal` after converting it with `float`.
function param(x::Number)
    return TrackedReal(float(x))
end

# Wrap an array in a `TrackedArray` after converting every element with `float`.
function param(xs::AbstractArray)
    return TrackedArray(float.(xs))
end

View File

@ -66,6 +66,15 @@ function back!(x, Δ; once = true)
return
end
# Gradient of `f` at `xs` without nesting support.
#
# Each input is wrapped with `param`, `f` is evaluated on the wrapped inputs,
# the result is validated by `losscheck`, and `back!` is run on it. The
# per-input `grad` values are then passed through `nobacksies` together with a
# message pointing at `gradient(...; nest = true)`.
function gradient_(f, xs...)
    tracked = param.(xs)
    loss = f(tracked...)
    losscheck(loss)
    back!(loss)
    hint = "Use `gradient(...; nest = true)` for nested derivatives"
    return nobacksies(hint, grad.(tracked))
end
# Out-of-place gradients
struct Params
@ -162,20 +171,11 @@ function losscheck(x)
isnan(x) && error("Loss is NaN")
end
function gradient(f, args...)
# Gradient of `f` at `args` computed via `forward`: evaluate `f` to obtain its
# value and a back function, validate the value with `losscheck`, then return
# the back function applied to a sensitivity of 1.
function gradient_nested(f, args...)
    value, pullback = forward(f, args...)
    losscheck(value)
    return pullback(1)
end
# Derivative of `f` at the single input `x`: the first entry of what
# `gradient(f, x)` returns.
function derivative(f, x)
    grads = gradient(f, x)
    return grads[1]
end
# Non-nesting versions
# Wrap each input with `param`, evaluate `f` on the wrapped inputs, validate
# the result with `losscheck`, run `back!` on it, and return the `grad` of
# every wrapped input.
function gradient_(f, xs...)
    tracked = param.(xs)
    loss = f(tracked...)
    losscheck(loss)
    back!(loss)
    return grad.(tracked)
end
# Entry point for gradients of `f` at `xs`. With `nest = true` the call is
# forwarded to `gradient_nested`; otherwise (the default) to `gradient_`.
function gradient(f, xs...; nest = false)
    if nest
        return gradient_nested(f, xs...)
    end
    return gradient_(f, xs...)
end

View File

@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@ -285,9 +285,9 @@ end
count += 1
a * b
end
@test derivative(x -> mul(5, x), 3) == 5
@test gradient(x -> mul(5, x), 3)[1] == 5
@test count == 1
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
@test gradient(x -> checkpoint(mul, 5, x), 3)[1] == 5
@test count == 3
end