diff --git a/src/activation.jl b/src/activation.jl
index 0bc6f466..ab269305 100644
--- a/src/activation.jl
+++ b/src/activation.jl
@@ -1 +1,3 @@
-abstract Activation <: Model
+export σ
+
+σ(x) = 1 ./ (1 .+ exp.(-x))
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index b9b35f2f..3d439f35 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -7,7 +7,8 @@ graph(vars, model, args...) = node(model, args...)
 graph(vars, x::mx.SymbolicNode) = x
 
 # TODO: detect parameters used more than once
-function graph(vars, value::AArray)
+function graph{T<:AArray}(vars, p::Flux.Param{T})
+  value = p.x
   id = gensym()
   vars[id] = value
   return mx.Variable(id)
@@ -28,3 +29,4 @@ end
 
 node(::typeof(*), args...) = mx.dot(args...)
 node(::typeof(+), args...) = mx.broadcast_plus(args...)
+node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 68cc3693..844ea915 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -7,6 +7,8 @@ end
 
 mxdims(dims::NTuple) = length(dims) == 1 ? (1, dims...) : reverse(dims)
 
+mxdims(n::Integer) = mxdims((n,))
+
 function mxargs(args)
   map(args) do kv
     arg, value = kv
@@ -37,13 +39,3 @@ function (model::MXModel)(input)
   mx.forward(model.exec)
   copy(model.exec.outputs[1])'
 end
-
-# d = Dense(20, 10)
-
-# x = randn(20)
-
-# model = mxnet(d, (20,))
-
-# d(x)
-
-# model(x)
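
For reference, a usage sketch reconstructed from the commented-out example this diff removes from the end of model.jl. This is a hypothetical session, not part of the patch: it assumes the Dense layer and the mxnet(model, input_shape) constructor available in Flux at this revision.

    d = Dense(20, 10)        # a dense layer; its forward pass can apply σ
    x = randn(20)

    model = mxnet(d, (20,))  # compile the Flux model to an MXNet executor;
                             # the new mxdims(::Integer) method forwards a bare
                             # integer shape to mxdims((n,)), so mxnet(d, 20)
                             # would presumably work too

    d(x)       # forward pass in native Julia
    model(x)   # forward pass via MXNet: the new node method lowers σ to
               # mx.Activation(data = x, act_type = :sigmoid)

Note that σ(x) = 1 ./ (1 .+ exp.(-x)) is written with broadcast dots, so the same definition applies elementwise to arrays as well as scalars, which is what lets one node(::typeof(σ), x) method cover the symbolic case in the MXNet graph.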