test params funct
This commit is contained in:
parent
961de2ba44
commit
0f7a1ec022
|
@ -33,11 +33,12 @@ function prefor(f, x; seen = OSet())
|
|||
return
|
||||
end
|
||||
|
||||
using Flux.Tracker: istracked
|
||||
|
||||
function params(m)
|
||||
ps = []
|
||||
prefor(p -> istracked(p) && push!(ps, p), m)
|
||||
prefor(p ->
|
||||
Tracker.istracked(p) && Tracker.isleaf(p) &&
|
||||
!(p in ps) && push!(ps, p),
|
||||
m)
|
||||
return ps
|
||||
end
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ using Flux.Tracker
|
|||
|
||||
@testset "Optimise" begin
|
||||
w = randn(10, 10)
|
||||
for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
|
||||
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
|
||||
w′ = param(randn(10, 10))
|
||||
loss(x) = Flux.mse(w*x, w′*x)
|
||||
opt = Opt([w′])
|
||||
|
|
|
@ -79,3 +79,10 @@ end
|
|||
@test std(v) < 1.1*sqrt(2/(n_in + n_out))
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Params" begin
|
||||
m = Dense(10, 5)
|
||||
@test size.(params(m)) == [(5, 10), (5,)]
|
||||
m = RNN(10, 5)
|
||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue