fixes #516
This commit is contained in:
parent
67d9016319
commit
a3e0de1ee5
|
@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
|
|||
end
|
||||
|
||||
function gradient_(f, xs...)
|
||||
xs = param.(xs)
|
||||
xs = param.(data.(xs))
|
||||
l = f(xs...)
|
||||
losscheck(l)
|
||||
back!(l)
|
||||
|
@ -179,3 +179,5 @@ end
|
|||
|
||||
gradient(f, xs...; nest = false) =
|
||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||
|
||||
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
|
||||
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
|
||||
using NNlib: conv, depthwiseconv
|
||||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
|
@ -260,7 +260,7 @@ Tracker.back!(b)
|
|||
back!(z)
|
||||
@test grad.((x,y)) == (3, 2)
|
||||
|
||||
@test Tracker.gradient(2, 3) do x, y
|
||||
@test gradient(2, 3) do x, y
|
||||
xy = Tracker.collect([x, y])
|
||||
xy[1]*xy[2]
|
||||
end == (3, 2)
|
||||
|
@ -295,4 +295,12 @@ end
|
|||
@test x == 7
|
||||
end
|
||||
|
||||
@testset "Params" begin
|
||||
W = param(randn(5, 10))
|
||||
x = rand(10)
|
||||
dW = gradient(W -> sum(W*x), W)[1]
|
||||
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
|
||||
@test gs[W] == dW
|
||||
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue