diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 6ccb9d39..f3a2b67b 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -1,4 +1,8 @@ -using ..Flux: stack, unstack +unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...)) +squeeze(xs, dim = 1) = Base.squeeze(xs, dim) + +stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...) +unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)] # Stateful Models diff --git a/test/compiler.jl b/test/compiler.jl index f12584be..e550b14f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,6 +1,5 @@ using DataFlow, MacroTools -using Flux: squeeze, unsqueeze, stack -using Flux.Compiler: @net, graph +using Flux.Compiler: @net, graph, stack, squeeze, unsqueeze using DataFlow: Line, Frame @net type Affine