nicer show
This commit is contained in:
parent
227e41c37b
commit
3e0f45046c
@ -15,6 +15,12 @@ Compiler.graph(s::Chain) =
|
|||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
|
function Base.show(io::IO, c::Chain)
|
||||||
|
print(io, "Chain(")
|
||||||
|
join(io, c.layers, ", ")
|
||||||
|
print(io, ")")
|
||||||
|
end
|
||||||
|
|
||||||
# Linear
|
# Linear
|
||||||
|
|
||||||
struct Linear{F,S,T}
|
struct Linear{F,S,T}
|
||||||
@ -27,3 +33,9 @@ Linear(in::Integer, out::Integer, σ = identity; init = initn) =
|
|||||||
Linear(σ, track(init(out, in)), track(init(out)))
|
Linear(σ, track(init(out, in)), track(init(out)))
|
||||||
|
|
||||||
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
(a::Linear)(x) = a.σ.(a.W*x .+ a.b)
|
||||||
|
|
||||||
|
function Base.show(io::IO, l::Linear)
|
||||||
|
print(io, "Linear(", size(l.W, 2), ", ", size(l.W, 1))
|
||||||
|
l.σ == identity || print(io, ", ", l.σ)
|
||||||
|
print(io, ")")
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user