organise batching utils
This commit is contained in:
parent
90616a3c5a
commit
568b8d7e48
@ -9,6 +9,10 @@ using Juno: Tree, Row
|
|||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
|
include("dims/catmat.jl")
|
||||||
|
include("dims/batching.jl")
|
||||||
|
include("dims/seq.jl")
|
||||||
|
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
@ -25,10 +29,6 @@ include("layers/shape.jl")
|
|||||||
include("layers/chain.jl")
|
include("layers/chain.jl")
|
||||||
include("layers/shims.jl")
|
include("layers/shims.jl")
|
||||||
|
|
||||||
include("dims/catmat.jl")
|
|
||||||
include("dims/batching.jl")
|
|
||||||
include("dims/seq.jl")
|
|
||||||
|
|
||||||
include("cost.jl")
|
include("cost.jl")
|
||||||
|
|
||||||
include("backend/backend.jl")
|
include("backend/backend.jl")
|
||||||
|
@ -29,13 +29,6 @@ storeparams!(m::Model) = storeparams!(m.session, m.params)
|
|||||||
|
|
||||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||||
|
|
||||||
function batch(xs)
|
|
||||||
dims = ndims(xs)-1
|
|
||||||
T = Array{eltype(xs),dims}
|
|
||||||
B = Array{eltype(xs),dims+1}
|
|
||||||
Batch{T,B}(xs)
|
|
||||||
end
|
|
||||||
|
|
||||||
function tferr(model::Model, e)
|
function tferr(model::Model, e)
|
||||||
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
||||||
m == nothing && return
|
m == nothing && return
|
||||||
@ -51,7 +44,7 @@ function runmodel(m::Model, args...)
|
|||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
try
|
try
|
||||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||||
ismultioutput(m) ? (batch.(output)...,) : batch(output)
|
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
|
||||||
catch e
|
catch e
|
||||||
isa(e, TensorFlow.TFException) || rethrow(e)
|
isa(e, TensorFlow.TFException) || rethrow(e)
|
||||||
tferr(m, e)
|
tferr(m, e)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, DataFlow, TensorFlow, Juno
|
using ..Flux, DataFlow, TensorFlow, Juno
|
||||||
import Flux: accuracy
|
import Flux: accuracy, rebatch
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
|
@ -11,11 +11,28 @@ Batch(xs) = Batch(CatMat(xs))
|
|||||||
convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
convert{T,S}(::Type{Batch{T,S}},storage::S) =
|
||||||
Batch{T,S}(storage)
|
Batch{T,S}(storage)
|
||||||
|
|
||||||
batchone(x) = Batch((x,))
|
|
||||||
batchone(x::Batch) = x
|
|
||||||
|
|
||||||
@render Juno.Inline b::Batch begin
|
@render Juno.Inline b::Batch begin
|
||||||
Tree(Row(Text("Batch of "), eltype(b),
|
Tree(Row(Text("Batch of "), eltype(b),
|
||||||
Juno.fade("[$(length(b))]")),
|
Juno.fade("[$(length(b))]")),
|
||||||
Juno.trim(collect(b)))
|
Juno.trim(collect(b)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Convenience methods for batch size 1
|
||||||
|
|
||||||
|
batchone(x) = Batch((x,))
|
||||||
|
batchone(x::Batch) = x
|
||||||
|
batchone(x::Tuple) = map(batchone, x)
|
||||||
|
|
||||||
|
function unbatchone(xs::Batch)
|
||||||
|
@assert length(xs) == 1
|
||||||
|
return first(xs)
|
||||||
|
end
|
||||||
|
|
||||||
|
unbatchone(xs::Tuple) = map(unbatchone, xs)
|
||||||
|
|
||||||
|
function rebatch(xs)
|
||||||
|
dims = ndims(xs)-1
|
||||||
|
T = Array{eltype(xs),dims}
|
||||||
|
B = Array{eltype(xs),dims+1}
|
||||||
|
Batch{T,B}(xs)
|
||||||
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user