diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 8d116480..60634634 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -8,7 +8,7 @@ type MXModel <: Model end Base.show(io::IO, m::MXModel) = - print(io, "MXModel($(unblock(syntax(Flux.graph(m.model)))))") + print(io, "MXModel($(m.model))") mxdims(dims::NTuple) = length(dims) == 1 ? (1, dims...) : reverse(dims) diff --git a/src/layers/chain.jl b/src/layers/chain.jl index 906e19c8..74f4a757 100644 --- a/src/layers/chain.jl +++ b/src/layers/chain.jl @@ -26,5 +26,11 @@ end back!(s::Chain, ∇) = foldr((m, ∇) -> back!(m, ∇), ∇, s.layers) update!(s::Chain, η) = foreach(l -> update!(l, η), s.layers) +function Base.show(io::IO, c::Chain) + print(io, "Chain(") + print_joined(io, c.layers, ", ") + print(io, ")") +end + graph(s::Chain) = foldl((v, m) -> vertex(m, v), constant(ModelInput(1)), s.layers) diff --git a/src/layers/dense.jl b/src/layers/dense.jl index e2c3682e..30b010e2 100644 --- a/src/layers/dense.jl +++ b/src/layers/dense.jl @@ -12,7 +12,7 @@ Dense(in::Integer, out::Integer; init = initn) = Dense(init(out, in), init(out)) Base.show(io::IO, d::Dense) = - print(io, "Flux.Dense($(size(d.W.x,2)),$(size(d.W.x,1)))") + print(io, "Dense($(size(d.W.x,2)),$(size(d.W.x,1)))") @model type Sigmoid layer::Model