use new batching approach in TensorFlow
This commit is contained in:
parent
46550e4863
commit
a06145a145
@ -2,6 +2,7 @@ module TF
|
|||||||
|
|
||||||
using ..Flux, Flow, TensorFlow, Juno
|
using ..Flux, Flow, TensorFlow, Juno
|
||||||
import Flux: accuracy
|
import Flux: accuracy
|
||||||
|
import TensorFlow: RawTensor
|
||||||
import Juno: info
|
import Juno: info
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
@ -63,11 +64,17 @@ function tf(model)
|
|||||||
Model(sess, [input], g, gradients(g, input))
|
Model(sess, [input], g, gradients(g, input))
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(args...)
|
batch(x) = Batch((x,))
|
||||||
|
|
||||||
|
RawTensor(data::Batch) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
|
function (m::Model)(args::Batch...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
Flux.unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
|
run(m.session, m.graph, Dict(zip(m.inputs, args)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(m::Model)(args...) = m(map(batch, args)...)
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, args...)
|
function Flux.back!(m::Model, Δ, args...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
# TODO: keyword arguments to `gradients`
|
# TODO: keyword arguments to `gradients`
|
||||||
|
@ -1,8 +1,4 @@
|
|||||||
export batch, Batch
|
export Batch
|
||||||
|
|
||||||
# TODO: support the Batch type only
|
|
||||||
batch(x) = reshape(x, (1,size(x)...))
|
|
||||||
batch(xs...) = vcat(map(batch, xs)...)
|
|
||||||
|
|
||||||
immutable Batch{T,S} <: AbstractVector{T}
|
immutable Batch{T,S} <: AbstractVector{T}
|
||||||
data::CatMat{T,S}
|
data::CatMat{T,S}
|
||||||
|
Loading…
Reference in New Issue
Block a user