update mxnet backend
This commit is contained in:
parent
cf2b168a55
commit
79981d6415
|
@ -1 +1,3 @@
|
|||
abstract Activation <: Model
|
||||
export σ
|
||||
|
||||
σ(x) = 1 ./ (1 .+ exp.(-x))
|
||||
|
|
|
@ -7,7 +7,8 @@ graph(vars, model, args...) = node(model, args...)
|
|||
graph(vars, x::mx.SymbolicNode) = x
|
||||
|
||||
# TODO: detect parameters used more than once
|
||||
function graph(vars, value::AArray)
|
||||
function graph{T<:AArray}(vars, p::Flux.Param{T})
|
||||
value = p.x
|
||||
id = gensym()
|
||||
vars[id] = value
|
||||
return mx.Variable(id)
|
||||
|
@ -28,3 +29,4 @@ end
|
|||
|
||||
node(::typeof(*), args...) = mx.dot(args...)
|
||||
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
|
|
|
@ -7,6 +7,8 @@ end
|
|||
mxdims(dims::NTuple) =
|
||||
length(dims) == 1 ? (1, dims...) : reverse(dims)
|
||||
|
||||
mxdims(n::Integer) = mxdims((n,))
|
||||
|
||||
function mxargs(args)
|
||||
map(args) do kv
|
||||
arg, value = kv
|
||||
|
@ -37,13 +39,3 @@ function (model::MXModel)(input)
|
|||
mx.forward(model.exec)
|
||||
copy(model.exec.outputs[1])'
|
||||
end
|
||||
|
||||
# d = Dense(20, 10)
|
||||
|
||||
# x = randn(20)
|
||||
|
||||
# model = mxnet(d, (20,))
|
||||
|
||||
# d(x)
|
||||
|
||||
# model(x)
|
||||
|
|
Loading…
Reference in New Issue