get rid of Group

This commit is contained in:
Mike J Innes 2016-11-15 21:09:58 +00:00
parent d86225ca47
commit 3c068744d2
3 changed files with 12 additions and 10 deletions

View File

@ -2,7 +2,7 @@ module Flux
using MacroTools, Lazy, DataFlow, Juno
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
iscyclic, Constant, constant, isconstant, Group, group, Split, splitnode,
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
spliceinputs, bumpinputs
using Juno: Tree, Row

View File

@ -1,21 +1,23 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, interptuple
using DataFlow: Constant, constant, Context, interpret, Split
using TensorFlow: RawTensor
# TODO: implement Julia's type promotion rules
node(x::Tensor) = x
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...)
graph(::typeof(.-), args...) = -(args...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(tanh), x) = tanh(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(.+), args...) = +(args...)
for op in (tanh, *, .*, +, -, .-)
@eval graph(::typeof($op), args...) = $op(args...)
end
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
@ -47,7 +49,7 @@ function interp(ctx, model, args...)
end
function tograph(model, args...)
ctx = Context(interptuple(interp), params = ObjectIdDict())
ctx = Context(interp, params = ObjectIdDict())
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], out
end

View File

@ -2,7 +2,7 @@ using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Co
function astuple(xs)
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
isa(xs, Vertex) && isa(value(xs), Group) ? inputs(xs) :
isa(xs, Vertex) && value(xs) == tuple ? inputs(xs) :
nothing
end