get rid of Group
This commit is contained in:
parent
d86225ca47
commit
3c068744d2
|
@ -2,7 +2,7 @@ module Flux
|
|||
|
||||
using MacroTools, Lazy, DataFlow, Juno
|
||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||
iscyclic, Constant, constant, isconstant, Group, group, Split, splitnode,
|
||||
iscyclic, Constant, constant, isconstant, group, Split, splitnode,
|
||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||
spliceinputs, bumpinputs
|
||||
using Juno: Tree, Row
|
||||
|
|
|
@ -1,21 +1,23 @@
|
|||
using Base: @get!
|
||||
using DataFlow: Constant, constant, Context, interpret, interptuple
|
||||
using DataFlow: Constant, constant, Context, interpret, Split
|
||||
using TensorFlow: RawTensor
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
node(x::Tensor) = x
|
||||
node(x::Tuple) = map(node, x)
|
||||
node(x::Tensor) = x
|
||||
node(x::Number) = TensorFlow.constant(Float32(x))
|
||||
|
||||
graph(::typeof(*), args...) = *(args...)
|
||||
graph(::typeof(.*), args...) = .*(args...)
|
||||
graph(::typeof(.-), args...) = -(args...)
|
||||
graph(::typeof(+), args...) = +(args...)
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(s::Split, t::Tuple) = t[s.n]
|
||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
graph(::typeof(tanh), x) = tanh(x)
|
||||
graph(::typeof(σ), x) = nn.sigmoid(x)
|
||||
graph(::typeof(.+), args...) = +(args...)
|
||||
|
||||
for op in (tanh, *, .*, +, -, .-)
|
||||
@eval graph(::typeof($op), args...) = $op(args...)
|
||||
end
|
||||
|
||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||
|
@ -47,7 +49,7 @@ function interp(ctx, model, args...)
|
|||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
ctx = Context(interptuple(interp), params = ObjectIdDict())
|
||||
ctx = Context(interp, params = ObjectIdDict())
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], out
|
||||
end
|
||||
|
|
|
@ -2,7 +2,7 @@ using DataFlow: interpret, interpret, interptuple, interplambda, interpconst, Co
|
|||
|
||||
function astuple(xs)
|
||||
isconstant(xs) && isa(value(xs).value, Tuple) ? value(xs).value :
|
||||
isa(xs, Vertex) && isa(value(xs), Group) ? inputs(xs) :
|
||||
isa(xs, Vertex) && value(xs) == tuple ? inputs(xs) :
|
||||
nothing
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue