fixes #516
This commit is contained in:
parent
67d9016319
commit
a3e0de1ee5
@ -67,7 +67,7 @@ function back!(x, Δ; once = true)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function gradient_(f, xs...)
|
function gradient_(f, xs...)
|
||||||
xs = param.(xs)
|
xs = param.(data.(xs))
|
||||||
l = f(xs...)
|
l = f(xs...)
|
||||||
losscheck(l)
|
losscheck(l)
|
||||||
back!(l)
|
back!(l)
|
||||||
@ -179,3 +179,5 @@ end
|
|||||||
|
|
||||||
gradient(f, xs...; nest = false) =
|
gradient(f, xs...; nest = false) =
|
||||||
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
|
||||||
|
|
||||||
|
gradient(f, ps::Params) = gradient_nested(f, ps)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
using Flux
|
using Flux
|
||||||
using Flux.Tracker, Test, NNlib
|
using Flux.Tracker, Test, NNlib
|
||||||
using Flux.Tracker: TrackedReal, gradcheck, grad, checkpoint
|
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint
|
||||||
using NNlib: conv, depthwiseconv
|
using NNlib: conv, depthwiseconv
|
||||||
using Printf: @sprintf
|
using Printf: @sprintf
|
||||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||||
@ -260,7 +260,7 @@ Tracker.back!(b)
|
|||||||
back!(z)
|
back!(z)
|
||||||
@test grad.((x,y)) == (3, 2)
|
@test grad.((x,y)) == (3, 2)
|
||||||
|
|
||||||
@test Tracker.gradient(2, 3) do x, y
|
@test gradient(2, 3) do x, y
|
||||||
xy = Tracker.collect([x, y])
|
xy = Tracker.collect([x, y])
|
||||||
xy[1]*xy[2]
|
xy[1]*xy[2]
|
||||||
end == (3, 2)
|
end == (3, 2)
|
||||||
@ -295,4 +295,12 @@ end
|
|||||||
@test x == 7
|
@test x == 7
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Params" begin
|
||||||
|
W = param(randn(5, 10))
|
||||||
|
x = rand(10)
|
||||||
|
dW = gradient(W -> sum(W*x), W)[1]
|
||||||
|
gs = gradient(() -> sum(W*x), Tracker.Params([W]))
|
||||||
|
@test gs[W] == dW
|
||||||
|
end
|
||||||
|
|
||||||
end #testset
|
end #testset
|
||||||
|
Loading…
Reference in New Issue
Block a user