tf reorg
This commit is contained in:
parent
82d69757c7
commit
0ad569596b
53
src/backend/tensorflow/graph.jl
Normal file
53
src/backend/tensorflow/graph.jl
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import Flow: Constant, postwalk, value, inputs, constant
|
||||||
|
import TensorFlow: RawTensor
|
||||||
|
|
||||||
|
cvalue(x) = x
|
||||||
|
cvalue(c::Constant) = c.value
|
||||||
|
cvalue(v::Vertex) = cvalue(value(v))
|
||||||
|
|
||||||
|
graph(x::Tensor) = x
|
||||||
|
|
||||||
|
graph(::typeof(*), args...) = *(args...)
|
||||||
|
graph(::typeof(+), args...) = +(args...)
|
||||||
|
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||||
|
graph(::typeof(relu), x) = nn.relu(x)
|
||||||
|
graph(::typeof(tanh), x) = tanh(x)
|
||||||
|
|
||||||
|
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||||
|
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||||
|
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
||||||
|
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
|
||||||
|
|
||||||
|
graph(::Input, x) = x
|
||||||
|
|
||||||
|
graph(c::Conv2D, x) =
|
||||||
|
nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
|
||||||
|
|
||||||
|
graph(p::MaxPool, x) =
|
||||||
|
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
||||||
|
|
||||||
|
type Op
|
||||||
|
f
|
||||||
|
shape
|
||||||
|
end
|
||||||
|
|
||||||
|
Op(f) = Op(f, (d...) -> nothing)
|
||||||
|
|
||||||
|
graph(op::Op, xs...) = op.f(xs...)
|
||||||
|
Flux.shape(op::Op, d...) = op.shape(d...)
|
||||||
|
|
||||||
|
# TODO: detect variable reuse
|
||||||
|
graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
|
||||||
|
|
||||||
|
function graph(model::Model, args...)
|
||||||
|
g = Flux.graph(model)
|
||||||
|
g ≠ nothing || error("No graph for $model")
|
||||||
|
g = spliceinputs(g, map(constant, args)...) |> detuple
|
||||||
|
postwalk(g) do v
|
||||||
|
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
|
||||||
|
end |> value
|
||||||
|
end
|
||||||
|
|
||||||
|
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
|
||||||
|
|
||||||
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
56
src/backend/tensorflow/model.jl
Normal file
56
src/backend/tensorflow/model.jl
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
type Model
|
||||||
|
session::Session
|
||||||
|
inputs::Vector{Tensor}
|
||||||
|
graph::Tensor
|
||||||
|
grad::Tensor
|
||||||
|
end
|
||||||
|
|
||||||
|
function tf(model)
|
||||||
|
sess = Session(Graph())
|
||||||
|
input = placeholder(Float32)
|
||||||
|
g = graph(model, input)
|
||||||
|
run(sess, initialize_all_variables())
|
||||||
|
Model(sess, [input], g, gradients(g, input))
|
||||||
|
end
|
||||||
|
|
||||||
|
batch(x) = Batch((x,))
|
||||||
|
|
||||||
|
function (m::Model)(args::Batch...)
|
||||||
|
@assert length(args) == length(m.inputs)
|
||||||
|
run(m.session, m.graph, Dict(zip(m.inputs, args)))
|
||||||
|
end
|
||||||
|
|
||||||
|
(m::Model)(args...) = m(map(batch, args)...)
|
||||||
|
|
||||||
|
function Flux.back!(m::Model, Δ, args...)
|
||||||
|
@assert length(args) == length(m.inputs)
|
||||||
|
# TODO: keyword arguments to `gradients`
|
||||||
|
run(m.session, m.grad, Dict(zip(m.inputs, args)))
|
||||||
|
end
|
||||||
|
|
||||||
|
function Flux.update!(m::Model)
|
||||||
|
error("update! is not yet supported on TensorFlow models")
|
||||||
|
end
|
||||||
|
|
||||||
|
import Juno: info
|
||||||
|
|
||||||
|
function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
||||||
|
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
||||||
|
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
||||||
|
i = 0
|
||||||
|
Y = placeholder(Float32)
|
||||||
|
Loss = loss(m.graph, Y)
|
||||||
|
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||||
|
for e in 1:epoch
|
||||||
|
info("Epoch $e\n")
|
||||||
|
@progress for (x, y) in train
|
||||||
|
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
|
||||||
|
Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
|
||||||
|
if i % 5000 == 0
|
||||||
|
@show y
|
||||||
|
@show accuracy(m, test)
|
||||||
|
end
|
||||||
|
i += 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
@ -1,117 +1,11 @@
|
|||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, Flow, TensorFlow, Juno
|
using ..Flux, Flow, TensorFlow, Juno
|
||||||
import Flow: Constant, postwalk, value, inputs, constant
|
|
||||||
import Flux: accuracy, spliceinputs, detuple
|
import Flux: accuracy, spliceinputs, detuple
|
||||||
import TensorFlow: RawTensor
|
|
||||||
import Juno: info
|
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
cvalue(x) = x
|
include("graph.jl")
|
||||||
cvalue(c::Constant) = c.value
|
include("model.jl")
|
||||||
cvalue(v::Vertex) = cvalue(value(v))
|
|
||||||
|
|
||||||
graph(x::Tensor) = x
|
|
||||||
|
|
||||||
# TODO: detect variable reuse
|
|
||||||
graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x)
|
|
||||||
|
|
||||||
function graph(model::Model, args...)
|
|
||||||
g = Flux.graph(model)
|
|
||||||
g ≠ nothing || error("No graph for $model")
|
|
||||||
g = spliceinputs(g, map(constant, args)...) |> detuple
|
|
||||||
postwalk(g) do v
|
|
||||||
vertex(graph(cvalue(v), cvalue.(inputs(v))...))
|
|
||||||
end |> value
|
|
||||||
end
|
|
||||||
|
|
||||||
graph(::typeof(*), args...) = *(args...)
|
|
||||||
graph(::typeof(+), args...) = +(args...)
|
|
||||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
|
||||||
graph(::typeof(relu), x) = nn.relu(x)
|
|
||||||
graph(::typeof(tanh), x) = tanh(x)
|
|
||||||
|
|
||||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
|
||||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
|
||||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
|
|
||||||
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
|
|
||||||
|
|
||||||
graph(::Input, x) = x
|
|
||||||
|
|
||||||
graph(c::Conv2D, x) =
|
|
||||||
nn.conv2d(x, graph(c.filter), [1,c.stride...,1], "VALID")
|
|
||||||
|
|
||||||
graph(p::MaxPool, x) =
|
|
||||||
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
|
|
||||||
|
|
||||||
TensorFlow.Tensor(m::Flux.Model, args...) = graph(m, args...)
|
|
||||||
|
|
||||||
type Model
|
|
||||||
session::Session
|
|
||||||
inputs::Vector{Tensor}
|
|
||||||
graph::Tensor
|
|
||||||
grad::Tensor
|
|
||||||
end
|
|
||||||
|
|
||||||
function tf(model)
|
|
||||||
sess = Session(Graph())
|
|
||||||
input = placeholder(Float32)
|
|
||||||
g = graph(model, input)
|
|
||||||
run(sess, initialize_all_variables())
|
|
||||||
Model(sess, [input], g, gradients(g, input))
|
|
||||||
end
|
|
||||||
|
|
||||||
batch(x) = Batch((x,))
|
|
||||||
|
|
||||||
RawTensor(data::Batch) = RawTensor(rawbatch(data))
|
|
||||||
|
|
||||||
function (m::Model)(args::Batch...)
|
|
||||||
@assert length(args) == length(m.inputs)
|
|
||||||
run(m.session, m.graph, Dict(zip(m.inputs, args)))
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::Model)(args...) = m(map(batch, args)...)
|
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, args...)
|
|
||||||
@assert length(args) == length(m.inputs)
|
|
||||||
# TODO: keyword arguments to `gradients`
|
|
||||||
run(m.session, m.grad, Dict(zip(m.inputs, args)))
|
|
||||||
end
|
|
||||||
|
|
||||||
function Flux.update!(m::Model)
|
|
||||||
error("update! is not yet supported on TensorFlow models")
|
|
||||||
end
|
|
||||||
|
|
||||||
function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|
||||||
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
|
||||||
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
|
||||||
i = 0
|
|
||||||
Y = placeholder(Float32)
|
|
||||||
Loss = loss(m.graph, Y)
|
|
||||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
|
||||||
for e in 1:epoch
|
|
||||||
info("Epoch $e\n")
|
|
||||||
@progress for (x, y) in train
|
|
||||||
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op),
|
|
||||||
Dict(m.inputs[1]=>batch(x), Y=>batch(y)))
|
|
||||||
if i % 5000 == 0
|
|
||||||
@show y
|
|
||||||
@show accuracy(m, test)
|
|
||||||
end
|
|
||||||
i += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
type Op
|
|
||||||
f
|
|
||||||
shape
|
|
||||||
end
|
|
||||||
|
|
||||||
Op(f) = Op(f, (d...) -> nothing)
|
|
||||||
|
|
||||||
graph(op::Op, xs...) = op.f(xs...)
|
|
||||||
Flux.shape(op::Op, d...) = op.shape(d...)
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user