organise batching utils
This commit is contained in:
parent
90616a3c5a
commit
568b8d7e48
@ -9,6 +9,10 @@ using Juno: Tree, Row
|
||||
|
||||
# Zero Flux Given
|
||||
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
|
||||
include("model.jl")
|
||||
include("utils.jl")
|
||||
include("data.jl")
|
||||
@ -25,10 +29,6 @@ include("layers/shape.jl")
|
||||
include("layers/chain.jl")
|
||||
include("layers/shims.jl")
|
||||
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
|
||||
include("cost.jl")
|
||||
|
||||
include("backend/backend.jl")
|
||||
|
@ -29,13 +29,6 @@ storeparams!(m::Model) = storeparams!(m.session, m.params)
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
|
||||
function batch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
||||
B = Array{eltype(xs),dims+1}
|
||||
Batch{T,B}(xs)
|
||||
end
|
||||
|
||||
function tferr(model::Model, e)
|
||||
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
||||
m == nothing && return
|
||||
@ -51,7 +44,7 @@ function runmodel(m::Model, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
try
|
||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
||||
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
|
||||
catch e
|
||||
isa(e, TensorFlow.TFException) || rethrow(e)
|
||||
tferr(m, e)
|
||||
|
@ -1,7 +1,7 @@
|
||||
module TF
|
||||
|
||||
using ..Flux, DataFlow, TensorFlow, Juno
|
||||
import Flux: accuracy
|
||||
import Flux: accuracy, rebatch
|
||||
|
||||
export tf
|
||||
|
||||
|
@ -11,11 +11,28 @@ Batch(xs) = Batch(CatMat(xs))
|
||||
convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
||||
Batch{T,S}(storage)
|
||||
|
||||
batchone(x) = Batch((x,))
|
||||
batchone(x::Batch) = x
|
||||
|
||||
@render Juno.Inline b::Batch begin
|
||||
Tree(Row(Text("Batch of "), eltype(b),
|
||||
Juno.fade("[$(length(b))]")),
|
||||
Juno.trim(collect(b)))
|
||||
end
|
||||
|
||||
# Convenience methods for batch size 1
|
||||
|
||||
batchone(x) = Batch((x,))
|
||||
batchone(x::Batch) = x
|
||||
batchone(x::Tuple) = map(batchone, x)
|
||||
|
||||
function unbatchone(xs::Batch)
|
||||
@assert length(xs) == 1
|
||||
return first(xs)
|
||||
end
|
||||
|
||||
unbatchone(xs::Tuple) = map(unbatchone, xs)
|
||||
|
||||
function rebatch(xs)
|
||||
dims = ndims(xs)-1
|
||||
T = Array{eltype(xs),dims}
|
||||
B = Array{eltype(xs),dims+1}
|
||||
Batch{T,B}(xs)
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user