update mxnet backend

This commit is contained in:
Mike J Innes 2016-08-23 14:28:54 +01:00
parent cf2b168a55
commit 79981d6415
3 changed files with 8 additions and 12 deletions

View File

@ -1 +1,3 @@
abstract Activation <: Model
export σ
σ(x) = 1 ./ (1 .+ exp.(-x))

View File

@ -7,7 +7,8 @@ graph(vars, model, args...) = node(model, args...)
graph(vars, x::mx.SymbolicNode) = x
# TODO: detect parameters used more than once
function graph(vars, value::AArray)
function graph{T<:AArray}(vars, p::Flux.Param{T})
value = p.x
id = gensym()
vars[id] = value
return mx.Variable(id)
@ -28,3 +29,4 @@ end
node(::typeof(*), args...) = mx.dot(args...)
node(::typeof(+), args...) = mx.broadcast_plus(args...)
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)

View File

@ -7,6 +7,8 @@ end
mxdims(dims::NTuple) =
length(dims) == 1 ? (1, dims...) : reverse(dims)
mxdims(n::Integer) = mxdims((n,))
function mxargs(args)
map(args) do kv
arg, value = kv
@ -37,13 +39,3 @@ function (model::MXModel)(input)
mx.forward(model.exec)
copy(model.exec.outputs[1])'
end
# d = Dense(20, 10)
# x = randn(20)
# model = mxnet(d, (20,))
# d(x)
# model(x)