update mxnet backend
This commit is contained in:
parent
cf2b168a55
commit
79981d6415
@ -1 +1,3 @@
|
|||||||
abstract Activation <: Model
|
export σ
|
||||||
|
|
||||||
|
σ(x) = 1 ./ (1 .+ exp.(-x))
|
||||||
|
@ -7,7 +7,8 @@ graph(vars, model, args...) = node(model, args...)
|
|||||||
graph(vars, x::mx.SymbolicNode) = x
|
graph(vars, x::mx.SymbolicNode) = x
|
||||||
|
|
||||||
# TODO: detect parameters used more than once
|
# TODO: detect parameters used more than once
|
||||||
function graph(vars, value::AArray)
|
function graph{T<:AArray}(vars, p::Flux.Param{T})
|
||||||
|
value = p.x
|
||||||
id = gensym()
|
id = gensym()
|
||||||
vars[id] = value
|
vars[id] = value
|
||||||
return mx.Variable(id)
|
return mx.Variable(id)
|
||||||
@ -28,3 +29,4 @@ end
|
|||||||
|
|
||||||
node(::typeof(*), args...) = mx.dot(args...)
|
node(::typeof(*), args...) = mx.dot(args...)
|
||||||
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||||
|
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||||
|
@ -7,6 +7,8 @@ end
|
|||||||
mxdims(dims::NTuple) =
|
mxdims(dims::NTuple) =
|
||||||
length(dims) == 1 ? (1, dims...) : reverse(dims)
|
length(dims) == 1 ? (1, dims...) : reverse(dims)
|
||||||
|
|
||||||
|
mxdims(n::Integer) = mxdims((n,))
|
||||||
|
|
||||||
function mxargs(args)
|
function mxargs(args)
|
||||||
map(args) do kv
|
map(args) do kv
|
||||||
arg, value = kv
|
arg, value = kv
|
||||||
@ -37,13 +39,3 @@ function (model::MXModel)(input)
|
|||||||
mx.forward(model.exec)
|
mx.forward(model.exec)
|
||||||
copy(model.exec.outputs[1])'
|
copy(model.exec.outputs[1])'
|
||||||
end
|
end
|
||||||
|
|
||||||
# d = Dense(20, 10)
|
|
||||||
|
|
||||||
# x = randn(20)
|
|
||||||
|
|
||||||
# model = mxnet(d, (20,))
|
|
||||||
|
|
||||||
# d(x)
|
|
||||||
|
|
||||||
# model(x)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user