get basic training working
parent 9e35bcd4b7
commit a2aade718d
@@ -1,6 +1,10 @@
 module TF

 using ..Flux, Flow, TensorFlow

+import Juno: info
+import Flux: accuracy
+
 export tf

 # Workaround for tensor display bug
 using Juno
@@ -12,10 +16,8 @@ cvalue(v::Vertex) = cvalue(value(v))

 graph(x::Tensor) = x

-matrixify(xs) = xs
-matrixify(xs::Vector) = xs[:,1:1]
 # TODO: detect variable reuse
-graph{T<:AArray}(p::Flux.Param{T}) = Variable(matrixify(p.x))
+graph{T<:AArray}(p::Flux.Param{T}) = Variable(p.x')

 function graph(model::Model, args...)
   g = Flux.graph(model)
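The new `Variable(p.x')` stores each parameter transposed: Flux keeps weights as (out × in) with samples as columns, while the TensorFlow graph wants (in × out) so a batch can multiply as rows. A minimal sketch of the layout swap, with hypothetical shapes:

```julia
W  = randn(2, 3)   # Flux-side weights: 2 outputs, 3 inputs
Wt = W'            # what Variable(p.x') hands to TensorFlow: 3×2
x  = randn(5, 3)   # TF-side batch: 5 samples as rows
y  = x * Wt        # 5×2, one output row per sample
```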
@@ -28,8 +30,9 @@ function graph(model::Model, args...)
   end |> value
 end

-graph(::typeof(*), args...) = *(args...)
+graph(::typeof(*), args...) = *(reverse(args)...)
 graph(::typeof(+), args...) = +(args...)
+graph(::typeof(softmax), x) = nn.softmax(x)

 type Model
   session::Session
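Reversing the arguments to `*` is the other half of the transposition story: Flux's graph records `W * x`, but the backend evaluates everything on transposed data, and `(A*B)' == B'*A'`, so the operands must swap. A one-line check in plain Julia:

```julia
A, B = randn(2, 3), randn(3, 4)
@assert (A*B)' ≈ B'*A'   # why graph(::typeof(*)) multiplies in reverse order
```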
@@ -41,7 +44,7 @@ end
 Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"

 function tf(model)
-  sess = Session()
+  sess = Session(Graph())
   input = placeholder(Float64)
   g = graph(model, input)
   run(sess, initialize_all_variables())
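`Session(Graph())` gives every compiled model its own graph rather than the process-wide default, so compiling several models can't make their nodes or variables collide. A minimal sketch, assuming TensorFlow.jl as imported above:

```julia
s1 = Session(Graph())   # model A's nodes live here
s2 = Session(Graph())   # model B's graph is fully independent
```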
@@ -50,7 +53,7 @@ end

 function (m::Model)(args...)
   @assert length(args) == length(m.inputs)
-  run(m.session, m.graph, Dict(zip(m.inputs, args)))
+  run(m.session, m.graph, Dict(zip(m.inputs, map(transpose, args))))'
 end

 function Flux.back!(m::Model, Δ, args...)
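With transposes on both sides of `run`, callers keep Julia's columns-as-samples layout end to end. A hedged usage sketch — `model0` and the 784-element input are made-up stand-ins:

```julia
m = tf(model0)   # model0: any Flux model (hypothetical)
x = rand(784)    # e.g. one flattened MNIST image
y = m(x)         # fed to the graph as x' (a row); the row result is transposed back
```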
@@ -63,4 +66,22 @@ function Flux.update!(m::Model)
   error("update! is not yet supported on TensorFlow models")
 end

+function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
+                     loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
+                     opt = TensorFlow.train.GradientDescentOptimizer(η))
+  i = 0
+  Y = placeholder(Float64)
+  Loss = loss(m.graph, Y)
+  minimize_op = TensorFlow.train.minimize(opt, Loss)
+  run(m.session, initialize_all_variables())
+  for e in 1:epoch
+    info("Epoch $e\n")
+    @progress for (x, y) in train
+      y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
+      i % 1000 == 0 && @show accuracy(m, test)
+      i += 1
+    end
+  end
+end
+
 end
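`train!` appends a label placeholder, a loss node, and a `minimize` op to the model's existing graph, then fetches the prediction, the loss, and the update together on every batch. A condensed sketch of the same mechanics, using only the TensorFlow.jl calls that appear in this diff (the linear model and shapes are hypothetical):

```julia
using TensorFlow

sess = Session(Graph())
X, Y = placeholder(Float64), placeholder(Float64)   # batch inputs and targets
W = Variable(randn(3, 2)/1000)                      # 3 features in, 2 outputs
Loss = reduce_sum((X*W - Y).^2)/2                   # the default loss above
opt = TensorFlow.train.GradientDescentOptimizer(0.1)
minimize_op = TensorFlow.train.minimize(opt, Loss)
run(sess, initialize_all_variables())

x, y = randn(5, 3), randn(5, 2)                     # one toy batch, rows as samples
cur_loss, _ = run(sess, [Loss, minimize_op], Dict(X=>x, Y=>y))
```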
@@ -5,7 +5,7 @@ const AArray = AbstractArray
 onehot(label, labels) = [i == label for i in labels]
 onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]

-initn(dims...) = randn(dims...)/100
+initn(dims...) = randn(dims...)/1000

 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
   i = 0
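`onehot` and `onecold` are inverses over a label set, and `initn` now draws smaller initial weights (scale 0.001 instead of 0.01). A quick round trip with a made-up label set:

```julia
labels = ["cat", "dog", "bird"]
v = onehot("dog", labels)   # Bool vector: [false, true, false]
onecold(v, labels)          # "dog": the label at the position of the maximum
```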
@@ -24,7 +24,7 @@ function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
   return m
 end

-function accuracy(m::Model, data)
+function accuracy(m, data)
   correct = 0
   for (x, y) in data
     onecold(m(x)) == onecold(y) && (correct += 1)
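Dropping the `::Model` restriction lets `accuracy` take either a native Flux model or a TF-backed one — anything callable. A hedged sketch with a made-up constant model:

```julia
m(x) = [0.9, 0.1]                     # hypothetical model: always predicts class 1
data = [(randn(3), onehot(1, 1:2)),   # two toy samples with one-hot targets
        (randn(3), onehot(2, 1:2))]
accuracy(m, data)                     # counts samples where onecold(m(x)) == onecold(y)
```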