Use a package-local squeeze function instead of extending Base

This commit is contained in:
Tony Kelman 2017-05-22 04:08:46 -04:00
parent 5cbb47a13d
commit 41ea071f3a
3 changed files with 3 additions and 2 deletions

View File

@ -9,6 +9,7 @@ const AArray = AbstractArray
initn(dims...) = randn(dims...)/100
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
squeeze(xs, dim = 1) = Base.squeeze(xs, dim)
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

View File

@ -13,5 +13,5 @@ end
_, ys = apply(unroll1(r).model, xs, (r.y.x,))
@test ys[1] == tanh(xs[1] * r.Wxy.x .+ r.y.x * r.Wyy.x .+ r.by.x)
ru = unroll(r, 3)
ru(batchone(Seq(squeeze.(xs, 1))))[1] == squeeze.(ys, 1)
ru(batchone(Seq(squeeze.(xs))))[1] == squeeze.(ys)
end

View File

@ -1,5 +1,5 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, unsqueeze
using Flux: graph, Param, squeeze, unsqueeze
using DataFlow: Line, Frame
macro mxonly(ex)