use new batching approach in TensorFlow
This commit is contained in:
parent
46550e4863
commit
a06145a145
|
@ -2,6 +2,7 @@ module TF
|
|||
|
||||
using ..Flux, Flow, TensorFlow, Juno
|
||||
import Flux: accuracy
|
||||
import TensorFlow: RawTensor
|
||||
import Juno: info
|
||||
|
||||
export tf
|
||||
|
@ -63,11 +64,17 @@ function tf(model)
|
|||
Model(sess, [input], g, gradients(g, input))
|
||||
end
|
||||
|
||||
function (m::Model)(args...)
|
||||
batch(x) = Batch((x,))
|
||||
|
||||
RawTensor(data::Batch) = RawTensor(rawbatch(data))
|
||||
|
||||
function (m::Model)(args::Batch...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
|
||||
run(m.session, m.graph, Dict(zip(m.inputs, args)))
|
||||
end
|
||||
|
||||
(m::Model)(args...) = m(map(batch, args)...)
|
||||
|
||||
function Flux.back!(m::Model, Δ, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
# TODO: keyword arguments to `gradients`
|
||||
|
|
|
@ -1,8 +1,4 @@
|
|||
export batch, Batch
|
||||
|
||||
# TODO: support the Batch type only
|
||||
batch(x) = reshape(x, (1,size(x)...))
|
||||
batch(xs...) = vcat(map(batch, xs)...)
|
||||
export Batch
|
||||
|
||||
immutable Batch{T,S} <: AbstractVector{T}
|
||||
data::CatMat{T,S}
|
||||
|
|
Loading…
Reference in New Issue