mx batch semantics
This commit is contained in:
parent
16d6c9aed9
commit
1f3587e9dc
@ -1,3 +1,4 @@
|
|||||||
|
using Flux: batchone, rebatch
|
||||||
|
|
||||||
type Model <: Flux.Model
|
type Model <: Flux.Model
|
||||||
model::Any
|
model::Any
|
||||||
@ -50,6 +51,10 @@ function runmodel(model::Model, input)
|
|||||||
copy(model.exec.outputs[1])
|
copy(model.exec.outputs[1])
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
|
||||||
|
|
||||||
|
(m::Model)(x) = first(m(batchone(x)))
|
||||||
|
|
||||||
function Flux.back!(model::Model, Δ, x)
|
function Flux.back!(model::Model, Δ, x)
|
||||||
ndzero!(model.grads[:input])
|
ndzero!(model.grads[:input])
|
||||||
mx.backward(model.exec, tond(Δ))
|
mx.backward(model.exec, tond(Δ))
|
||||||
|
@ -5,9 +5,8 @@ let dt = tf(d)
|
|||||||
@test d(xs) ≈ dt(xs)
|
@test d(xs) ≈ dt(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: batch semantics
|
|
||||||
let dm = mxnet(d, (1, 20))
|
let dm = mxnet(d, (1, 20))
|
||||||
@test d(xs)' ≈ dm(xs')
|
@test d(xs) ≈ dm(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
# TensorFlow native integration
|
# TensorFlow native integration
|
||||||
|
Loading…
Reference in New Issue
Block a user