This commit is contained in:
Mike J Innes 2019-01-15 15:48:38 +00:00
parent 67d9016319
commit a3e0de1ee5
2 changed files with 13 additions and 3 deletions

View File

@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
end
function gradient_(f, xs...)
xs = param.(xs)
xs = param.(data.(xs))
l = f(xs...)
losscheck(l)
back!(l)
@ -179,3 +179,5 @@ end
gradient(f, xs...; nest = false) =
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
gradient(f, ps::Params) = gradient_nested(f, ps)

View File

@ -1,6 +1,6 @@
using Flux
using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
@ -260,7 +260,7 @@ Tracker.back!(b)
back!(z)
@test grad.((x,y)) == (3, 2)
@test Tracker.gradient(2, 3) do x, y
@test gradient(2, 3) do x, y
xy = Tracker.collect([x, y])
xy[1]*xy[2]
end == (3, 2)
@ -295,4 +295,12 @@ end
@test x == 7
end
@testset "Params" begin
W = param(randn(5, 10))
x = rand(10)
dW = gradient(W -> sum(W*x), W)[1]
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
@test gs[W] == dW
end
end #testset