From 3e0f45046cabd0996fba2a6a650514a47ecce65a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 21 Aug 2017 17:20:09 +0100 Subject: [PATCH] nicer show --- src/layers/basic.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 59349aba..a6a8bd62 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -15,6 +15,12 @@ Compiler.graph(s::Chain) = Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +function Base.show(io::IO, c::Chain) + print(io, "Chain(") + join(io, c.layers, ", ") + print(io, ")") +end + # Linear struct Linear{F,S,T} @@ -27,3 +33,9 @@ Linear(in::Integer, out::Integer, σ = identity; init = initn) = Linear(σ, track(init(out, in)), track(init(out))) (a::Linear)(x) = a.σ.(a.W*x .+ a.b) + +function Base.show(io::IO, l::Linear) + print(io, "Linear(", size(l.W, 2), ", ", size(l.W, 1)) + l.σ == identity || print(io, ", ", l.σ) + print(io, ")") +end