Tiny bugfix: `stack` was still calling julia 0.6 `cat`
Also added tiny test for good measure.
This commit is contained in:
parent
9781f063aa
commit
9b897fc601
|
@ -10,7 +10,7 @@ zeros(dims...) = Base.zeros(Float32, dims...)
|
|||
|
||||
unsqueeze(xs, dim) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||
|
||||
stack(xs, dim) = cat(dim, unsqueeze.(xs, dim)...)
|
||||
stack(xs, dim) = cat(unsqueeze.(xs, dim)...; dims=dim)
|
||||
unstack(xs, dim) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
using Flux
|
||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal
|
||||
using Flux: throttle, jacobian, glorot_uniform, glorot_normal, stack
|
||||
using StatsBase: std
|
||||
using Random
|
||||
using Test
|
||||
|
@ -86,3 +86,9 @@ end
|
|||
m = RNN(10, 5)
|
||||
@test size.(params(m)) == [(5, 10), (5, 5), (5,), (5,)]
|
||||
end
|
||||
|
||||
@testset "Basic" begin
|
||||
x = randn(3,3)
|
||||
stacked = stack([x, x], 2)
|
||||
@test size(stacked) == (3,2,3)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue