diff --git a/src/Flux.jl b/src/Flux.jl index ff78593f..acefff19 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -31,4 +31,6 @@ include("layers/normalisation.jl") include("data/Data.jl") +include("batches/Batches.jl") + end # module diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl new file mode 100644 index 00000000..066f4d1c --- /dev/null +++ b/src/batches/Batches.jl @@ -0,0 +1,7 @@ +module Batches + +import ..Flux + +include("batch.jl") + +end diff --git a/src/batches/batch.jl b/src/batches/batch.jl new file mode 100644 index 00000000..5a2eb82e --- /dev/null +++ b/src/batches/batch.jl @@ -0,0 +1,8 @@ +struct Batch{T,A,M} + data::A + mask::M +end + +Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask) + +Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))