diff --git a/src/utils.jl b/src/utils.jl index d25f8e2b..347ff202 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ const AArray = AbstractArray initn(dims...) = randn(dims...)/100 unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) -Base.squeeze(xs) = squeeze(xs, 1) +squeeze(xs, dim = 1) = Base.squeeze(xs, dim) stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] diff --git a/test/runtests.jl b/test/runtests.jl index e110725e..4a86c6e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Flux, DataFlow, MacroTools, Base.Test -using Flux: graph, Param, unsqueeze +using Flux: graph, Param, squeeze, unsqueeze using DataFlow: Line, Frame macro mxonly(ex)