Flux.jl/src/backend/tensorflow/tensorflow.jl

module TF

using ..Flux, Flow, TensorFlow, Juno
import Flux: accuracy
import Juno: info

export tf

cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
graph(x::Tensor) = x
# TODO: detect variable reuse
graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)

# Lower a Flux model applied to `args` into a TensorFlow graph: substitute the
# model's arguments/fields for its `ModelInput` placeholders, then translate
# each node of the Flow graph via the `graph` methods below.
function graph(model::Model, args...)
  g = Flux.graph(model)
  g ≠ nothing || error("No graph for $model")
  g = Flow.mapconst(g) do x
    !isa(x, Flux.ModelInput) ? x :
    isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
  end
  postwalk(g) do v
    vertex(graph(cvalue(v), cvalue.(inputs(v))...))
  end |> value
end

# Translate Flux's primitive ops and activations to their TensorFlow equivalents.
graph(::typeof(*), args...) = *(args...)
graph(::typeof(+), args...) = +(args...)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(tanh), x) = tanh(x)
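
# Further primitives could be mapped in the same way; for example (an
# illustrative sketch, assuming Flux's `σ` and TensorFlow.jl's `nn.sigmoid`
# are available):
#
#   graph(::typeof(σ), x) = nn.sigmoid(x)
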
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
graph(::Input, x) = x
2016-09-27 01:16:49 +00:00
graph(c::Conv2D, x) =
  nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")

TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
# Treat the first dimension as the batch index
# TODO: custom data type for this
batch(x) = reshape(x, (1,size(x)...))
batch(xs...) = vcat(map(batch, xs)...)
unbatch(xs) = reshape(xs, size(xs)[2:end])
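
# For example (illustrative, not part of the original file): `batch([1,2,3])`
# gives a 1×3 row, `batch([1,2,3], [4,5,6])` stacks two samples into a 2×3
# array, and `unbatch` drops the leading singleton batch dimension again.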

# Holds the TF session together with the model's input placeholders, its
# output tensor and the gradient of the output with respect to the inputs.
type Model
  session::Session
  inputs::Vector{Tensor}
  graph::Tensor
  grad::Tensor
end

# Compile a Flux model into a TF-backed `Model` with a single input placeholder.
function tf(model)
  sess = Session(Graph())
  input = placeholder(Float32)
  g = graph(model, input)
  run(sess, initialize_all_variables())
  Model(sess, [input], g, gradients(g, input))
end

function (m::Model)(args...)
  @assert length(args) == length(m.inputs)
  unbatch(run(m.session, m.graph, Dict(zip(m.inputs, map(batch, args)))))
end
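
# Usage sketch (illustrative; `Affine` stands in for any Flux model that the
# `graph` translations above can handle):
#
#   m = tf(Affine(10, 5))       # compile the Flux model to a TF graph
#   ŷ = m(rand(Float32, 10))    # run the session on a single sample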

function Flux.back!(m::Model, Δ, args...)
  @assert length(args) == length(m.inputs)
  # TODO: keyword arguments to `gradients`
  run(m.session, m.grad, Dict(zip(m.inputs, args)))
end

function Flux.update!(m::Model)
  error("update! is not yet supported on TensorFlow models")
end

function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
                     loss = (y, ŷ) -> reduce_sum((y - ŷ).^2)/2,
                     opt = TensorFlow.train.GradientDescentOptimizer(η))
  i = 0
  Y = placeholder(Float32)
  Loss = loss(m.graph, Y)
  minimize_op = TensorFlow.train.minimize(opt, Loss)
  for e in 1:epoch
    info("Epoch $e\n")
    @progress for (x, y) in train
      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
                           Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
      if i % 5000 == 0
        @show y
        @show accuracy(m, test)
      end
      i += 1
    end
  end
end
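
# Training sketch (illustrative; `train` and `test` are assumed to be
# iterables of `(x, y)` pairs):
#
#   Flux.train!(m, train, test; epoch = 5, η = 0.01)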

# Wrap a raw TensorFlow op (plus an optional shape-inference function) so it
# can be spliced into a Flux graph; `graph` lowers it by calling the wrapped
# function on its translated inputs.
type Op
  f
  shape
end

Op(f) = Op(f, (d...) -> nothing)

graph(op::Op, xs...) = op.f(xs...)
Flux.shape(op::Op, d...) = op.shape(d...)
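
# For example (an illustrative sketch, assuming the shape function should
# return the output dimensions unchanged):
#
#   myrelu = Op(x -> nn.relu(x), (d...) -> d)
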
end