diff --git a/src/Flux.jl b/src/Flux.jl index 1df87404..b0a04624 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -9,6 +9,10 @@ using Juno: Tree, Row # Zero Flux Given +include("dims/catmat.jl") +include("dims/batching.jl") +include("dims/seq.jl") + include("model.jl") include("utils.jl") include("data.jl") @@ -25,10 +29,6 @@ include("layers/shape.jl") include("layers/chain.jl") include("layers/shims.jl") -include("dims/catmat.jl") -include("dims/batching.jl") -include("dims/seq.jl") - include("cost.jl") include("backend/backend.jl") diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index ef6d2040..6ad1abe3 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -29,13 +29,6 @@ storeparams!(m::Model) = storeparams!(m.session, m.params) ismultioutput(m::Model) = !isa(m.output, Tensor) -function batch(xs) - dims = ndims(xs)-1 - T = Array{eltype(xs),dims} - B = Array{eltype(xs),dims+1} - Batch{T,B}(xs) -end - function tferr(model::Model, e) m = match(r"Node: ([\w\d]+) =", string(e.status)) m == nothing && return @@ -51,7 +44,7 @@ function runmodel(m::Model, args...) @assert length(args) == length(m.inputs) try output = run(m.session, m.output, Dict(zip(m.inputs, args))) - ismultioutput(m) ? (batch.(output)...,) : batch(output) + ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output) catch e isa(e, TensorFlow.TFException) || rethrow(e) tferr(m, e) diff --git a/src/backend/tensorflow/tensorflow.jl b/src/backend/tensorflow/tensorflow.jl index f2c27f2b..7551fe56 100644 --- a/src/backend/tensorflow/tensorflow.jl +++ b/src/backend/tensorflow/tensorflow.jl @@ -1,7 +1,7 @@ module TF using ..Flux, DataFlow, TensorFlow, Juno -import Flux: accuracy +import Flux: accuracy, rebatch export tf diff --git a/src/dims/batching.jl b/src/dims/batching.jl index 8faf50f6..24dfc85d 100644 --- a/src/dims/batching.jl +++ b/src/dims/batching.jl @@ -11,11 +11,28 @@ Batch(xs) = Batch(CatMat(xs)) convert{T,S}(::Type{Batch{T,S}},storage::S) = Batch{T,S}(storage) -batchone(x) = Batch((x,)) -batchone(x::Batch) = x - @render Juno.Inline b::Batch begin Tree(Row(Text("Batch of "), eltype(b), Juno.fade("[$(length(b))]")), Juno.trim(collect(b))) end + +# Convenience methods for batch size 1 + +batchone(x) = Batch((x,)) +batchone(x::Batch) = x +batchone(x::Tuple) = map(batchone, x) + +function unbatchone(xs::Batch) + @assert length(xs) == 1 + return first(xs) +end + +unbatchone(xs::Tuple) = map(unbatchone, xs) + +function rebatch(xs) + dims = ndims(xs)-1 + T = Array{eltype(xs),dims} + B = Array{eltype(xs),dims+1} + Batch{T,B}(xs) +end