Merge remote-tracking branch 'upstream/master' into add-more-tf-ops
This commit is contained in:
commit
c350bfb672
@ -91,9 +91,8 @@ end
|
|||||||
register(ctx::Context, node) = node
|
register(ctx::Context, node) = node
|
||||||
|
|
||||||
function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam})
|
function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam})
|
||||||
id = gensym()
|
haskey(ctx[:params], p) && return ctx[:params][p]
|
||||||
ctx[:params][id] = p
|
ctx[:params][p] = mx.Variable(gensym())
|
||||||
return mx.Variable(id)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
var(ctx::Context, x) = x
|
var(ctx::Context, x) = x
|
||||||
@ -110,10 +109,11 @@ graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
|
|||||||
|
|
||||||
function tograph(model, args...; feedforward = false)
|
function tograph(model, args...; feedforward = false)
|
||||||
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
|
||||||
params = Dict(), stacks = Dict(),
|
params = ObjectIdDict(), stacks = Dict(),
|
||||||
feedforward = feedforward)
|
feedforward = feedforward)
|
||||||
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
|
||||||
return Graph(args, out, ctx[:params], ctx[:stacks])
|
params = Dict(nodename(v) => p for (p, v) in ctx[:params])
|
||||||
|
return Graph(args, out, params, ctx[:stacks])
|
||||||
end
|
end
|
||||||
|
|
||||||
# Error Handling
|
# Error Handling
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, DataFlow, TensorFlow, Juno
|
using ..Flux, DataFlow, TensorFlow, Juno
|
||||||
import Flux: accuracy, rebatch, convertel
|
import Flux: accuracy, convertel
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ function test_back(bk)
|
|||||||
update!(dm, 0.1)
|
update!(dm, 0.1)
|
||||||
|
|
||||||
@test dm(xs) ≈ d(xs)
|
@test dm(xs) ≈ d(xs)
|
||||||
@test dm(xs) ≉ d′(xs)
|
@test !(dm(xs) ≈ d′(xs))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -28,4 +28,12 @@ using Flux: MaxPool
|
|||||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Duplicate parameters" begin
|
||||||
|
a = Affine(10, 10)
|
||||||
|
d = Chain(a, a)
|
||||||
|
m = mxnet(d)
|
||||||
|
m(randn(1, 10))
|
||||||
|
@test length(m.graph.params) == 2
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user