diff --git a/src/dims/batching.jl b/src/dims/batching.jl index 8b9da07a..a3ae6ed8 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -27,21 +27,3 @@ end convertel(T::Type, xs::Batch) = eltype(eltype(xs)) isa T ? xs : Batch(map(x->convertel(T, x), xs)) - -# Add batching semantics to functions operating on raw arrays -# TODO: remove this in favour of full batching semantics - -mapt(f, x) = f(x) -mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) - -batchone(x) = Batch((x,)) -batchone(x::Batch) = x - -function unbatchone(xs::Batch) - @assert length(xs) == 1 - return first(xs) -end - -isbatched(x) = false -isbatched(x::Batch) = true -isbatched(xs::Tuple) = any(isbatched, xs) diff --git a/src/utils.jl b/src/utils.jl index d1b81347..bae5736c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ export AArray const AArray = AbstractArray +mapt(f, x) = f(x) +mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs) + initn(dims...) = randn(dims...)/100 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)