basic batch type

This commit is contained in:
Mike J Innes 2017-10-11 11:54:18 +01:00
parent 21ea93ffcd
commit efa51f02e7
3 changed files with 17 additions and 0 deletions

View File

@ -31,4 +31,6 @@ include("layers/normalisation.jl")
include("data/Data.jl")
include("batches/Batches.jl")
end # module

7
src/batches/Batches.jl Normal file
View File

@ -0,0 +1,7 @@
module Batches
import ..Flux
include("batch.jl")
end

8
src/batches/batch.jl Normal file
View File

@ -0,0 +1,8 @@
struct Batch{T,A,M}
data::A
mask::M
end
Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))