Merge remote-tracking branch 'upstream/master' into add-more-tf-ops-2
This commit is contained in:
commit
e6db3b0e89
@ -59,6 +59,7 @@ end
|
|||||||
rawbatch(xs::Batchable) = rawbatch(storage(xs))
|
rawbatch(xs::Batchable) = rawbatch(storage(xs))
|
||||||
size(xs::Batchable) = size(storage(xs))
|
size(xs::Batchable) = size(storage(xs))
|
||||||
getindex(xs::Batchable, i) = getindex(storage(xs), i)
|
getindex(xs::Batchable, i) = getindex(storage(xs), i)
|
||||||
|
setindex!(xs::Batchable, v, i...) = setindex!(storage(xs), v, i...)
|
||||||
|
|
||||||
Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys))
|
Base.vcat{T<:Batchable}(xs::T, ys::T)::T = vcat(rawbatch(xs), rawbatch(ys))
|
||||||
|
|
||||||
@ -84,7 +85,8 @@ dimdec(T::Type{<:AbstractArray}) = deparam(T){eltype(T),ndims(T)-1}
|
|||||||
|
|
||||||
btype(B::Type, S::Type{<:AbstractArray}) = B
|
btype(B::Type, S::Type{<:AbstractArray}) = B
|
||||||
btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S}
|
btype(B::Type{<:Batchable}, S::Type{<:AbstractArray}) = B{dimdec(S),S}
|
||||||
btype(B::Type{<:Batchable{T}} where T, S::Type{<:AbstractArray}) = B{S}
|
btype{T}(B::Type{<:Batchable{T}}, S::Type{<:AbstractArray}) = B{S}
|
||||||
|
btype{T,S<:AbstractArray}(B::Type{<:Batchable{T,S}}, ::Type{S}) = B
|
||||||
btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) =
|
btype(B::Type{<:Batchable{<:Batchable}}, S::Type{<:AbstractArray}) =
|
||||||
deparam(B){btype(eltype(B), dimdec(S)),S}
|
deparam(B){btype(eltype(B), dimdec(S)),S}
|
||||||
|
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
# Simple version
|
||||||
|
|
||||||
|
using Base.Iterators: partition
|
||||||
|
|
||||||
|
partitionr(xs, n) = take(partition(xs, n), length(xs)÷n)
|
||||||
|
|
||||||
|
batches(xs, n) = (Batch([xs...]) for xs in partitionr(xs, n))
|
||||||
|
|
||||||
# Stateful iteration
|
# Stateful iteration
|
||||||
|
|
||||||
@ -40,6 +46,8 @@ end
|
|||||||
|
|
||||||
# Batched
|
# Batched
|
||||||
|
|
||||||
|
import Base: start, next, done, iteratorsize, iteratoreltype, eltype, length
|
||||||
|
|
||||||
struct Batched{I<:StatefulIter,S}
|
struct Batched{I<:StatefulIter,S}
|
||||||
itr::I
|
itr::I
|
||||||
buf::S
|
buf::S
|
||||||
@ -48,7 +56,9 @@ end
|
|||||||
function Batched(itr, n::Integer)
|
function Batched(itr, n::Integer)
|
||||||
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
n >= 1 || throw(ArgumentError("batch size must be >= 1"))
|
||||||
itr = StatefulIter(itr)
|
itr = StatefulIter(itr)
|
||||||
buf = convert(Batch, similar(eltype(itr)(), n, size(peek(itr))...))
|
x = peek(itr)
|
||||||
|
buf = convert(Batch{typeof(peek(itr))},
|
||||||
|
similar(rawbatch(x), n, size(rawbatch(x))...))
|
||||||
Batched(itr, buf)
|
Batched(itr, buf)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -65,6 +75,8 @@ next(x::Batched, _) = x.buf, ()
|
|||||||
function done(x::Batched, _)
|
function done(x::Batched, _)
|
||||||
next = taken!(x.itr, length(x.buf))
|
next = taken!(x.itr, length(x.buf))
|
||||||
length(next) < length(x.buf) && return true
|
length(next) < length(x.buf) && return true
|
||||||
x.buf[:] = next
|
for (i, n) in enumerate(next)
|
||||||
|
x.buf[i] = rawbatch(n)
|
||||||
|
end
|
||||||
return false
|
return false
|
||||||
end
|
end
|
||||||
|
@ -145,7 +145,7 @@ end
|
|||||||
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
|
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, input, feedforward=true)
|
graph = tograph(model, input, feedforward=true)
|
||||||
ff = mx.FeedForward(graph.output, context = context)
|
ff = mx.FeedForward(graph.output, context = ctx)
|
||||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
|
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
|
||||||
return ff
|
return ff
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user