support maps over tuples
This commit is contained in:
parent
d788cc8c54
commit
790a58be1d
@ -40,6 +40,8 @@ graph(::typeof(softmax), xs) =
|
|||||||
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
|
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
|
||||||
graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
||||||
|
|
||||||
|
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
graph(ctx::Context, d::Affine, x) =
|
graph(ctx::Context, d::Affine, x) =
|
||||||
|
@ -39,6 +39,8 @@ end
|
|||||||
|
|
||||||
graph(::typeof(.-), args...) = -(args...)
|
graph(::typeof(.-), args...) = -(args...)
|
||||||
|
|
||||||
|
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
|
||||||
|
|
||||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
||||||
|
@ -55,3 +55,11 @@ function test_stacktrace(bk)
|
|||||||
@test e.trace[2].func == :TLP
|
@test e.trace[2].func == :TLP
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function test_anon(bk)
|
||||||
|
@testset "Closures" begin
|
||||||
|
x, y = rand(3), rand(5)
|
||||||
|
model = bk(@net xs -> map(x -> x .* x, xs))
|
||||||
|
@test all(model((x, y)) .≈ (x.*x, y.*y))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -13,6 +13,7 @@ test_tupleio(mxnet)
|
|||||||
test_recurrence(mxnet)
|
test_recurrence(mxnet)
|
||||||
test_stacktrace(mxnet)
|
test_stacktrace(mxnet)
|
||||||
test_back(mxnet)
|
test_back(mxnet)
|
||||||
|
test_anon(mxnet)
|
||||||
|
|
||||||
@testset "Native interface" begin
|
@testset "Native interface" begin
|
||||||
f = mx.FeedForward(Chain(d, softmax))
|
f = mx.FeedForward(Chain(d, softmax))
|
||||||
|
@ -12,6 +12,7 @@ dt = tf(d)
|
|||||||
test_tupleio(tf)
|
test_tupleio(tf)
|
||||||
test_recurrence(tf)
|
test_recurrence(tf)
|
||||||
test_stacktrace(tf)
|
test_stacktrace(tf)
|
||||||
|
test_anon(tf)
|
||||||
|
|
||||||
@testset "Tensor interface" begin
|
@testset "Tensor interface" begin
|
||||||
sess = TensorFlow.Session()
|
sess = TensorFlow.Session()
|
||||||
|
@ -18,6 +18,8 @@ Flux.infer(d, (1, 10))
|
|||||||
# @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
# @test isa(x, DataFlow.Input) && isa(W, Param) && isa(b, Param)
|
||||||
# end
|
# end
|
||||||
|
|
||||||
|
test_anon(identity)
|
||||||
|
|
||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
tlp = TLP(a1, a2)
|
tlp = TLP(a1, a2)
|
||||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
|
Loading…
Reference in New Issue
Block a user