basic batch type

This commit is contained in:
Mike J Innes 2017-10-11 11:54:18 +01:00
parent 83cc77c860
commit 173fc5d178
3 changed files with 17 additions and 0 deletions

View File

@ -27,4 +27,6 @@ include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/recurrent.jl")
include("batches/Batches.jl")
end # module

7
src/batches/Batches.jl Normal file
View File

@ -0,0 +1,7 @@
module Batches
import ..Flux
include("batch.jl")
end

8
src/batches/batch.jl Normal file
View File

@ -0,0 +1,8 @@
struct Batch{T,A,M}
data::A
mask::M
end
Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))