From 4ccbbbb2841d3bfdd36e3557f4c428406eb81b2d Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 8 Jun 2017 10:49:39 +0100 Subject: [PATCH] dup params fix --- src/backend/mxnet/graph.jl | 10 +++++----- test/backend/mxnet.jl | 8 ++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 4e3eae01..5e532864 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -91,9 +91,8 @@ end register(ctx::Context, node) = node function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam}) - id = gensym() - ctx[:params][id] = p - return mx.Variable(id) + haskey(ctx[:params], p) && return ctx[:params][p] + ctx[:params][p] = mx.Variable(gensym()) end var(ctx::Context, x) = x @@ -110,10 +109,11 @@ graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...) function tograph(model, args...; feedforward = false) ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′), - params = Dict(), stacks = Dict(), + params = ObjectIdDict(), stacks = Dict(), feedforward = feedforward) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...) - return Graph(args, out, ctx[:params], ctx[:stacks]) + params = Dict(nodename(v) => p for (p, v) in ctx[:params]) + return Graph(args, out, params, ctx[:stacks]) end # Error Handling diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl index f0aeb0b7..cad5407e 100644 --- a/test/backend/mxnet.jl +++ b/test/backend/mxnet.jl @@ -28,4 +28,12 @@ using Flux: MaxPool @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)] end +@testset "Duplicate parameters" begin + a = Affine(10, 10) + d = Chain(a, a) + m = mxnet(d) + m(randn(1, 10)) + @test length(m.graph.params) == 2 +end + end