use batching api

This commit is contained in:
Mike J Innes 2016-10-28 17:00:31 +01:00
parent 2852dddf0f
commit c5a64391a1
1 changed file with 11 additions and 4 deletions

@@ -17,14 +17,21 @@ function tf(model)
         [gradients(output, input)])
 end
 
-batch(x) = Batch((x,))
+batchone(x) = Batch((x,))
+
+function batch(xs)
+  dims = ndims(xs)-1
+  T = Array{eltype(xs),dims}
+  B = Array{eltype(xs),dims+1}
+  Batch{T,B}(xs)
+end
 
 function (m::Model)(args::Batch...)
   @assert length(args) == length(m.inputs)
-  run(m.session, m.outputs[1], Dict(zip(m.inputs, args)))
+  batch(run(m.session, m.outputs[1], Dict(zip(m.inputs, args))))
 end
 
-(m::Model)(args...) = m(map(batch, args)...)
+(m::Model)(args...) = first(m(map(batchone, args)...))
 
 function Flux.back!(m::Model, Δ, args...)
   @assert length(args) == length(m.inputs)
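
For context on the hunk above: a minimal sketch of the distinction it introduces, assuming a Batch-style container whose leading dimension indexes samples (as the dims = ndims(xs)-1 computation suggests). The MiniBatch type and names below are hypothetical stand-ins for illustration only, not the Flux API.

    # Hypothetical stand-in for the Batch container assumed by this diff.
    # T is the type of a single sample, B the type of the batched storage.
    struct MiniBatch{T,B}
      data::B
    end

    # batchone-style: wrap one sample as a batch containing a single element.
    minibatchone(x) = MiniBatch{typeof(x),Tuple{typeof(x)}}((x,))

    # batch-style: an array whose first dimension indexes samples is wrapped
    # as-is; each sample is then an (ndims-1)-dimensional slice.
    function minibatch(xs::AbstractArray)
      dims = ndims(xs) - 1
      T = Array{eltype(xs),dims}
      B = Array{eltype(xs),dims+1}
      MiniBatch{T,B}(xs)
    end

    minibatchone([1.0, 2.0, 3.0])   # one 3-vector, batch length 1
    minibatch(rand(10, 3))          # ten 3-vectors, batch length 10
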
@@ -49,7 +56,7 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
     info("Epoch $e\n")
     @progress for (x, y) in train
       y, cur_loss, _ = run(m.session, vcat(m.outputs[1], Loss, minimize_op),
-                           Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
+                           Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
       if i % 5000 == 0
         @show y
         @show accuracy(m, test)
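
The single-sample call change above ((m::Model)(args...) = first(m(map(batchone, args)...))) and the batchone calls in train! follow the same pattern: wrap each single sample as a batch of one, run the batched graph, then take the first (and only) result back out. A self-contained toy sketch of that shape, using made-up toy_* names rather than Flux or TensorFlow.jl calls:

    # A "model" that only understands batched input: each row of X is a sample.
    batched_model(W, X::AbstractMatrix) = X * W'

    # batchone analogue for a plain vector: add a leading batch dimension of 1.
    toy_batchone(x::AbstractVector) = reshape(x, 1, length(x))

    # Single-sample call: batch the input, run, and take the first (only) row.
    toy_call(W, x::AbstractVector) = batched_model(W, toy_batchone(x))[1, :]

    W = [1.0 2.0; 3.0 4.0]
    toy_call(W, [1.0, 1.0])   # -> [3.0, 7.0]
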