Add x to seen in prefor to avoid infinite recursion if passed something self-referential
This commit is contained in:
parent
b3bba4c566
commit
16d5f2bc24
@ -31,6 +31,7 @@ end
|
|||||||
|
|
||||||
function prefor(f, x; seen = IdSet())
|
function prefor(f, x; seen = IdSet())
|
||||||
x ∈ seen && return
|
x ∈ seen && return
|
||||||
|
push!(seen, x)
|
||||||
f(x)
|
f(x)
|
||||||
foreach(x -> prefor(f, x, seen = seen), children(x))
|
foreach(x -> prefor(f, x, seen = seen), children(x))
|
||||||
return
|
return
|
||||||
|
@ -85,6 +85,17 @@ end
|
|||||||
@test size.(params(m)) == [(5, 10), (5,)]
|
@test size.(params(m)) == [(5, 10), (5,)]
|
||||||
m = RNN(10, 5)
|
m = RNN(10, 5)
|
||||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||||
|
|
||||||
|
# Layer duplicated in same chain, params just once pls.
|
||||||
|
c = Chain(m, m)
|
||||||
|
@test size.(params(c)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||||
|
|
||||||
|
# Recursive struct. Just want params, no stack overflow pls.
|
||||||
|
mutable struct R m;r end
|
||||||
|
Flux.@treelike R
|
||||||
|
r = R(m, nothing)
|
||||||
|
r.r = r
|
||||||
|
@test size.(params(r)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Basic Stacking" begin
|
@testset "Basic Stacking" begin
|
||||||
|
Loading…
Reference in New Issue
Block a user