From efa51f02e7a7ea28d79aabe496cdb57aedbae4fd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 11 Oct 2017 11:54:18 +0100 Subject: [PATCH] basic batch type --- src/Flux.jl | 2 ++ src/batches/Batches.jl | 7 +++++++ src/batches/batch.jl | 8 ++++++++ 3 files changed, 17 insertions(+) create mode 100644 src/batches/Batches.jl create mode 100644 src/batches/batch.jl diff --git a/src/Flux.jl b/src/Flux.jl index ff78593f..acefff19 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -31,4 +31,6 @@ include("layers/normalisation.jl") include("data/Data.jl") +include("batches/Batches.jl") + end # module diff --git a/src/batches/Batches.jl b/src/batches/Batches.jl new file mode 100644 index 00000000..066f4d1c --- /dev/null +++ b/src/batches/Batches.jl @@ -0,0 +1,7 @@ +module Batches + +import ..Flux + +include("batch.jl") + +end diff --git a/src/batches/batch.jl b/src/batches/batch.jl new file mode 100644 index 00000000..5a2eb82e --- /dev/null +++ b/src/batches/batch.jl @@ -0,0 +1,8 @@ +struct Batch{T,A,M} + data::A + mask::M +end + +Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask) + +Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))