Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
173fc5d178 |
@ -27,4 +27,6 @@ include("layers/stateless.jl")
|
|||||||
include("layers/basic.jl")
|
include("layers/basic.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
|
|
||||||
|
include("batches/Batches.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
7
src/batches/Batches.jl
Normal file
7
src/batches/Batches.jl
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
module Batches
|
||||||
|
|
||||||
|
import ..Flux
|
||||||
|
|
||||||
|
include("batch.jl")
|
||||||
|
|
||||||
|
end
|
8
src/batches/batch.jl
Normal file
8
src/batches/batch.jl
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
struct Batch{T,A,M}
|
||||||
|
data::A
|
||||||
|
mask::M
|
||||||
|
end
|
||||||
|
|
||||||
|
Batch{T}(data, mask) where T = Batch{T,typeof(data),typeof(mask)}(data, mask)
|
||||||
|
|
||||||
|
Batch(xs) = Batch{typeof(first(xs))}(Flux.batch(xs),trues(length(xs)))
|
Loading…
Reference in New Issue
Block a user