mx batch semantics

This commit is contained in:
Mike J Innes 2017-01-30 23:35:15 +05:30
parent 16d6c9aed9
commit 1f3587e9dc
2 changed files with 6 additions and 2 deletions

View File

@ -1,3 +1,4 @@
using Flux: batchone, rebatch
type Model <: Flux.Model
model::Any
@ -50,6 +51,10 @@ function runmodel(model::Model, input)
copy(model.exec.outputs[1])
end
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
(m::Model)(x) = first(m(batchone(x)))
function Flux.back!(model::Model, Δ, x)
ndzero!(model.grads[:input])
mx.backward(model.exec, tond(Δ))

View File

@ -5,9 +5,8 @@ let dt = tf(d)
@test d(xs) dt(xs)
end
# TODO: batch semantics
let dm = mxnet(d, (1, 20))
@test d(xs)' dm(xs')
@test d(xs) dm(xs)
end
# TensorFlow native integration