Merge remote-tracking branch 'upstream/master' into add-more-tf-ops

This commit is contained in:
Ali Hamdi 2017-06-08 11:59:59 +02:00
commit c350bfb672
4 changed files with 15 additions and 7 deletions

View File

@ -91,9 +91,8 @@ end
register(ctx::Context, node) = node register(ctx::Context, node) = node
function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam}) function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam})
id = gensym() haskey(ctx[:params], p) && return ctx[:params][p]
ctx[:params][id] = p ctx[:params][p] = mx.Variable(gensym())
return mx.Variable(id)
end end
var(ctx::Context, x) = x var(ctx::Context, x) = x
@ -110,10 +109,11 @@ graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...; feedforward = false) function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph), ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph),
params = Dict(), stacks = Dict(), params = ObjectIdDict(), stacks = Dict(),
feedforward = feedforward) feedforward = feedforward)
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...) out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
return Graph(args, out, ctx[:params], ctx[:stacks]) params = Dict(nodename(v) => p for (p, v) in ctx[:params])
return Graph(args, out, params, ctx[:stacks])
end end
# Error Handling # Error Handling

View File

@ -1,7 +1,7 @@
module TF module TF
using ..Flux, DataFlow, TensorFlow, Juno using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, rebatch, convertel import Flux: accuracy, convertel
export tf export tf

View File

@ -39,7 +39,7 @@ function test_back(bk)
update!(dm, 0.1) update!(dm, 0.1)
@test dm(xs) d(xs) @test dm(xs) d(xs)
@test dm(xs) d(xs) @test !(dm(xs) d(xs))
end end
end end

View File

@ -28,4 +28,12 @@ using Flux: MaxPool
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)] @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end end
@testset "Duplicate parameters" begin
a = Affine(10, 10)
d = Chain(a, a)
m = mxnet(d)
m(randn(1, 10))
@test length(m.graph.params) == 2
end
end end