Add x to seen in prefor to avoid infinite recursion if passed something self-referential

This commit is contained in:
DrChainsaw 2019-07-08 23:11:35 +02:00
parent b3bba4c566
commit 16d5f2bc24
2 changed files with 12 additions and 0 deletions

View File

@ -31,6 +31,7 @@ end
function prefor(f, x; seen = IdSet())
x seen && return
push!(seen, x)
f(x)
foreach(x -> prefor(f, x, seen = seen), children(x))
return

View File

@ -85,6 +85,17 @@ end
@test size.(params(m)) == [(5, 10), (5,)]
m = RNN(10, 5)
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
# Layer duplicated in same chain, params just once pls.
c = Chain(m, m)
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
# Recursive struct. Just want params, no stack overflow pls.
mutable struct R m;r end
Flux.@treelike R
r = R(m, nothing)
r.r = r
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
end
@testset "Basic Stacking" begin