Add x to seen in prefor to avoid infinite recursion if passed something self-referential
This commit is contained in:
parent
b3bba4c566
commit
16d5f2bc24
|
@ -31,6 +31,7 @@ end
|
|||
|
||||
function prefor(f, x; seen = IdSet())
|
||||
x ∈ seen && return
|
||||
push!(seen, x)
|
||||
f(x)
|
||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||
return
|
||||
|
|
|
@ -85,6 +85,17 @@ end
|
|||
@test size.(params(m)) == [(5, 10), (5,)]
|
||||
m = RNN(10, 5)
|
||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
|
||||
# Layer duplicated in same chain, params just once pls.
|
||||
c = Chain(m, m)
|
||||
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
|
||||
# Recursive struct. Just want params, no stack overflow pls.
|
||||
mutable struct R m;r end
|
||||
Flux.@treelike R
|
||||
r = R(m, nothing)
|
||||
r.r = r
|
||||
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
end
|
||||
|
||||
@testset "Basic Stacking" begin
|
||||
|
|
Loading…
Reference in New Issue