test params funct

This commit is contained in:
Mike J Innes 2018-02-08 16:13:20 +00:00
parent 961de2ba44
commit 0f7a1ec022
3 changed files with 12 additions and 4 deletions

View File

@ -33,11 +33,12 @@ function prefor(f, x; seen = OSet())
return
end
using Flux.Tracker: istracked
function params(m)
ps = []
prefor(p -> istracked(p) && push!(ps, p), m)
prefor(p ->
Tracker.istracked(p) && Tracker.isleaf(p) &&
!(p in ps) && push!(ps, p),
m)
return ps
end

View File

@ -3,7 +3,7 @@ using Flux.Tracker
@testset "Optimise" begin
w = randn(10, 10)
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
w = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w*x)
opt = Opt([w])

View File

@ -79,3 +79,10 @@ end
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
end
end
@testset "Params" begin
m = Dense(10, 5)
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
end