import Base: @get!
import DataFlow: Constant, postwalk, value, inputs, constant
import TensorFlow: RawTensor

# TODO: implement Julia's type promotion rules
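
# `cvalue` unwraps DataFlow `Constant`s and vertices down to the underlying
# Julia value; anything else passes through unchanged.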
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
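
# Lowering rules: map Julia/Flux functions and layer types onto the
# corresponding TensorFlow ops.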
graph(x::Tensor) = x
graph(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(*), args...) = *(args...)
graph(::typeof(.*), args...) = .*(args...)
graph(::typeof(.-), args...) = -(args...) # TF's subtraction broadcasts
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(tanh), x) = tanh(x)
graph(::typeof(σ), x) = nn.sigmoid(x)

# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
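# `batchsize` reads the batch dimension from the tensor's runtime shape, and
# `flatten`/`Reshape` splice it back in, so reshapes preserve the batch
# dimension whatever its size.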
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))

graph(::Input, x) = x

graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")

graph(::DataFlow.Group, xs...) = (xs...,)
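
# Methods below also take a `params` dictionary, which threads trainable
# parameters through lowering (see the `Flux.Param` method).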
graph(params::Associative, c::Conv2D, x) =
  nn.conv2d(x, graph(params, c.filter), [1, c.stride..., 1], "VALID")

graph(op::Op, xs...) = op.f(xs...)

Flux.shape(op::Op, d...) = op.shape(d...)
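
# A `Flux.Param` lowers to a TF `Variable`; `@get!` memoises the mapping, so a
# parameter shared across the model lowers to a single variable.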
graph{T<:AArray}(params::Associative, p::Flux.Param{T}) =
  @get!(params, p, Variable(p.x))
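
# Lower a DataFlow graph: splice the call's arguments into the graph's inputs,
# then walk bottom-up, replacing each vertex with the TF tensor it lowers to.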
function graph(params::Associative, v::IVertex, args...)
  # TODO: check number of arguments
  v = spliceinputs(v, map(constant, args)...) |> detuple
  postwalk(v) do v
    vertex(graph(params, cvalue(v), cvalue.(inputs(v))...))
  end |> value
end
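
# Fallback for models: if the model exposes a dataflow graph, lower that.
# Cyclic (recurrent) graphs must be unrolled before they can be lowered.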
function graph(params::Associative, model, args...)
  g = Flux.graph(model)
  g == nothing && return graph(model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  graph(params, g, args...)
end
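
# `tograph` lowers a model to TF, collecting the parameter map as it goes.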
function tograph(model, args...)
  params = Dict{Flux.Param,Tensor}()
  g = graph(params, model, args...)
  return params, g
end
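
# Illustrative usage (assuming a Flux `model` and a TF placeholder `x`):
#   params, y = tograph(model, x)
# `y` is the model's output tensor; `params` maps each `Flux.Param` to the
# `Variable` backing it in the graph.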

TensorFlow.Tensor(m::Flux.Model, args...) = graph(Dict(), m, args...)

RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
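
# Build a fresh session for `model` with `n` Float32 placeholder inputs;
# returns the session, parameter map, input placeholders and output tensor.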
function makesession(model, n)
  sess = Session(Graph())
  inputs = [placeholder(Float32) for _ = 1:n]
  params, output = tograph(model, inputs...)
  run(sess, initialize_all_variables())
  sess, params, inputs, output
end
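
# Copy trained variable values out of the session back into the `Flux.Param`s.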
function storeparams!(sess, params)
  for (p, t) in params
    p.x = run(sess, t)
  end
end