mx batch semantics
This commit is contained in:
parent
16d6c9aed9
commit
1f3587e9dc
|
@ -1,3 +1,4 @@
|
|||
using Flux: batchone, rebatch
|
||||
|
||||
type Model <: Flux.Model
|
||||
model::Any
|
||||
|
@ -50,6 +51,10 @@ function runmodel(model::Model, input)
|
|||
copy(model.exec.outputs[1])
|
||||
end
|
||||
|
||||
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
|
||||
|
||||
(m::Model)(x) = first(m(batchone(x)))
|
||||
|
||||
function Flux.back!(model::Model, Δ, x)
|
||||
ndzero!(model.grads[:input])
|
||||
mx.backward(model.exec, tond(Δ))
|
||||
|
|
|
@ -5,9 +5,8 @@ let dt = tf(d)
|
|||
@test d(xs) ≈ dt(xs)
|
||||
end
|
||||
|
||||
# TODO: batch semantics
|
||||
let dm = mxnet(d, (1, 20))
|
||||
@test d(xs)' ≈ dm(xs')
|
||||
@test d(xs) ≈ dm(xs)
|
||||
end
|
||||
|
||||
# TensorFlow native integration
|
||||
|
|
Loading…
Reference in New Issue