organise batching utils

This commit is contained in:
Mike J Innes 2017-01-24 15:54:30 +05:30
parent 90616a3c5a
commit 568b8d7e48
4 changed files with 26 additions and 16 deletions

View File

@ -9,6 +9,10 @@ using Juno: Tree, Row
# Zero Flux Given
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")
include("model.jl")
include("utils.jl")
include("data.jl")
@ -25,10 +29,6 @@ include("layers/shape.jl")
include("layers/chain.jl")
include("layers/shims.jl")
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")
include("cost.jl")
include("backend/backend.jl")

View File

@ -29,13 +29,6 @@ storeparams!(m::Model) = storeparams!(m.session, m.params)
ismultioutput(m::Model) = !isa(m.output, Tensor)
function batch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}
B = Array{eltype(xs),dims+1}
Batch{T,B}(xs)
end
function tferr(model::Model, e)
m = match(r"Node: ([\w\d]+) =", string(e.status))
m == nothing && return
@ -51,7 +44,7 @@ function runmodel(m::Model, args...)
@assert length(args) == length(m.inputs)
try
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (batch.(output)...,) : batch(output)
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
catch e
isa(e, TensorFlow.TFException) || rethrow(e)
tferr(m, e)

View File

@ -1,7 +1,7 @@
module TF
using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy
import Flux: accuracy, rebatch
export tf

View File

@ -11,11 +11,28 @@ Batch(xs) = Batch(CatMat(xs))
convert{T,S}(::Type{Batch{T,S}},storage::S) =
Batch{T,S}(storage)
batchone(x) = Batch((x,))
batchone(x::Batch) = x
@render Juno.Inline b::Batch begin
Tree(Row(Text("Batch of "), eltype(b),
Juno.fade("[$(length(b))]")),
Juno.trim(collect(b)))
end
# Convenience methods for batch size 1
batchone(x) = Batch((x,))
batchone(x::Batch) = x
batchone(x::Tuple) = map(batchone, x)
function unbatchone(xs::Batch)
@assert length(xs) == 1
return first(xs)
end
unbatchone(xs::Tuple) = map(unbatchone, xs)
function rebatch(xs)
dims = ndims(xs)-1
T = Array{eltype(xs),dims}
B = Array{eltype(xs),dims+1}
Batch{T,B}(xs)
end