Tiny bugfix: stack
was still calling julia 0.6 cat
Also added tiny test for good measure.
This commit is contained in:
parent
9781f063aa
commit
9b897fc601
@ -10,7 +10,7 @@ zeros(dims...) = Base.zeros(Float32, dims...)
|
|||||||
|
|
||||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
|
|
||||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims=dim)
|
||||||
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux
|
using Flux
|
||||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
|
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack
|
||||||
using StatsBase: std
|
using StatsBase: std
|
||||||
using Random
|
using Random
|
||||||
using Test
|
using Test
|
||||||
@ -86,3 +86,9 @@ end
|
|||||||
m = RNN(10, 5)
|
m = RNN(10, 5)
|
||||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Basic" begin
|
||||||
|
x = randn(3,3)
|
||||||
|
stacked = stack([x, x], 2)
|
||||||
|
@test size(stacked) == (3,2,3)
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user