From 9c8dbb6b4b7bb162057f82a64f90eca5c81ffec3 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 9 Jun 2017 18:54:35 +0100 Subject: [PATCH 1/3] feedforward fix --- src/backend/mxnet/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 6a09d44f..2488e3c1 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -145,7 +145,7 @@ end function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu()) model = rewrite_softmax(model, label) graph = tograph(model, input, feedforward=true) - ff = mx.FeedForward(graph.output, context = context) + ff = mx.FeedForward(graph.output, context = ctx) isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx))) return ff end From 65400f20aba444083605d94644f8bcc3a172ac99 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 9 Jun 2017 18:55:21 +0100 Subject: [PATCH 2/3] nested batch tweaks --- src/Batches/catmat.jl | 4 +++- src/Batches/iter.jl | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/Batches/catmat.jl b/src/Batches/catmat.jl index 0289b5f2..dcf336c1 100644 --- a/src/Batches/catmat.jl +++ b/src/Batches/catmat.jl @@ -59,6 +59,7 @@ end rawbatch(xs::Batchable) = rawbatch(storage(xs)) size(xs::Batchable) = size(storage(xs)) getindex(xs::Batchable, i) = getindex(storage(xs), i) +setindex!(xs::Batchable, v, i...) = setindex!(storage(xs), v, i...) Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys)) @@ -84,7 +85,8 @@ dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1} btype(B::Type, S::Type{<:AbstractArray}) = B btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S} -btype(B::Type{<:Batchable{T}} where T, S::Type{<:AbstractArray}) = B{S} +btype{T}(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) = B{S} +btype{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, ::Type{S}) = B btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) = deparam(B){btype(eltype(B), dimdec(S)),S} diff --git a/src/Batches/iter.jl b/src/Batches/iter.jl index 051541e6..70673739 100644 --- a/src/Batches/iter.jl +++ b/src/Batches/iter.jl @@ -48,7 +48,9 @@ end function Batched(itr, n::Integer) n >= 1 || throw(ArgumentError("batch size must be >= 1")) itr = StatefulIter(itr) - buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...)) + x = peek(itr) + buf = convert(Batch{typeof(peek(itr))}, + similar(rawbatch(x), n, size(rawbatch(x))...)) Batched(itr, buf) end @@ -65,6 +67,8 @@ next(x::Batched, _) = x.buf, () function done(x::Batched, _) next = taken!(x.itr, length(x.buf)) length(next) < length(x.buf) && return true - x.buf[:] = next + for (i, n) in enumerate(next) + x.buf[i] = rawbatch(n) + end return false end From 358ba650adacdb2e67d1e179044978e7bbf467b3 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 9 Jun 2017 18:57:18 +0100 Subject: [PATCH 3/3] more robust `batches` --- src/Batches/iter.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Batches/iter.jl b/src/Batches/iter.jl index 70673739..27fd6c26 100644 --- a/src/Batches/iter.jl +++ b/src/Batches/iter.jl @@ -1,4 +1,10 @@ -import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length +# Simple version + +using Base.Iterators: partition + +partitionr(xs, n) = take(partition(xs, n), length(xs)÷n) + +batches(xs, n) = (Batch([xs...]) for xs in partitionr(xs, n)) # Stateful iteration @@ -40,6 +46,8 @@ end # Batched +import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length + struct Batched{I<:StatefulIter,S} itr::I buf::S