diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 8f9da6ff..d461c95c 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -101,7 +101,7 @@ struct Diagonal{T} β::T end -Diagonal(in::Integer; initα = ones, initβ = (x) -> similar(x) .= 0) = +Diagonal(in::Integer; initα = ones, initβ = zeros) = Diagonal(param(initα(in)), param(initβ(in))) @treelike Diagonal