remove special-cased training code
splits model and exec to allow multi inputs
This commit is contained in:
parent d4ee8a6a2a
commit ba54552be5
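The new call path, sketched from the test changes below: tf now returns a lazy Model wrapper, and the TensorFlow graph (an Exec) is only built on the first call, with one placeholder per argument, which is what makes multi-input models work. A rough usage sketch, assuming the Affine and Multi layers from the test suite:

using Flux
Flux.loadtf()

d  = Affine(20, 10)            # single-input layer from the tests
dt = tf(d)                     # lazy wrapper; no graph built yet
dt(rand(1, 20))                # first call builds an Exec with one placeholder

m  = Multi(20, 15)             # two-input layer used in the tests
mm = tf(m)
mm(rand(1, 20), rand(1, 20))   # builds an Exec with two placeholders, then feeds both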
@@ -1,39 +1,40 @@
-struct Model
-  model::Any
+using Flux: mapt
+
+struct Exec
   session::Session
-  params::Dict{Flux.Param,Tensor}
-  stacks::Dict
-  inputs::Vector{Tensor}
+  input::Any
   output::Any
+  params::Dict{Flux.Param,Tensor}
+  stacks::Dict{Any,Any}
 end
 
 function makesession(model, inputs; session = Session(Graph()))
   params, stacks, output = tograph(model, inputs...)
   run(session, global_variables_initializer())
-  Model(model, session, params, stacks, inputs, output)
+  Exec(session, inputs, output, params, stacks)
 end
 
 function makesession(model, n::Integer; session = Session(Graph()))
   makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
 end
 
-tf(model) = makesession(model, 1)
-
-function storeparams!(sess, params)
-  for (p, t) in params
-    p.x = run(sess, t)
-  end
+function (m::Exec)(args...)
+  @assert length(args) == length(m.input)
+  run(m.session, m.output, Dict(zip(m.input, args)))
 end
 
-storeparams!(m::Model) = storeparams!(m.session, m.params)
-
-function runmodel(m::Model, args...)
-  @assert length(args) == length(m.inputs)
-  run(m.session, m.output, Dict(zip(m.inputs, args)))
+mutable struct Model
+  model::Any
+  exec::Exec
+  Model(model) = new(model)
 end
 
-function (m::Model)(x)
-  @tferr m.stacks runmodel(m, convert.(Float32, x))
+tf(model) = Model(model)
+
+function (m::Model)(args...)
+  args = mapt(x->convert.(Float32, x),args)
+  isdefined(m, :graph) || (m.exec = makesession(m.model, length(args)))
+  @tferr m.exec.stacks m.exec(args...)
 end
 
 for f in :[back!, update!].args
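The Model(model) = new(model) inner constructor above leaves the exec field undefined, so the TensorFlow session can be created lazily once the number of inputs is known from the first call. A minimal, self-contained sketch of that deferred-construction idiom in plain Julia (hypothetical names, no TensorFlow):

mutable struct Lazy
  input::Any
  compiled::Function          # built on demand
  Lazy(input) = new(input)    # leave the compiled field undefined
end

function (l::Lazy)(args...)
  # build the expensive part only on the first call
  isdefined(l, :compiled) || (l.compiled = (xs...) -> sum(xs) .+ l.input)
  l.compiled(args...)
end

l = Lazy(1.0)
l(2.0, 3.0)                   # 6.0; construction happens on this first call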
@@ -41,27 +42,3 @@ for f in :[back!, update!].args
     error($(string(f)) * " is not yet supported on TensorFlow models")
   end
 end
-
-import Juno: info
-
-function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
-                     loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
-                     opt = TensorFlow.train.GradientDescentOptimizer(η))
-  i = 0
-  Y = placeholder(Float32)
-  Loss = loss(m.output, Y)
-  minimize_op = TensorFlow.train.minimize(opt, Loss)
-  for e in 1:epoch
-    info("Epoch $e\n")
-    @progress for (x, y) in train
-      y, cur_loss, _ = run(m.session, [m.output, Loss, minimize_op],
-                           Dict(m.inputs[1] => batchone(convertel(Float32, x)),
-                                Y => batchone(convertel(Float32, y))))
-      if i % 5000 == 0
-        @show y
-        @show accuracy(m, test)
-      end
-      i += 1
-    end
-  end
-end
@@ -3,12 +3,16 @@ Flux.loadtf()
 
 @testset "TensorFlow" begin
 
-xs = rand(1, 20)
+xs, ys = rand(1, 20), rand(1, 20)
 d = Affine(20, 10)
 
 dt = tf(d)
 @test d(xs) ≈ dt(xs)
 
+m = Multi(20, 15)
+mm = tf(m)
+@test all(isapprox.(mm(xs, ys), m(xs, ys)))
+
 @testset "Tensor interface" begin
   sess = TensorFlow.Session()
   X = placeholder(Float32)
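For context, the Exec call above is a thin layer over TensorFlow.jl's placeholder/feed-dict mechanism, which the Tensor interface testset exercises directly. A small standalone sketch of that pattern, using only ops that appear elsewhere in this diff (values are illustrative):

using TensorFlow

sess = Session(Graph())
X = placeholder(Float32)
Y = placeholder(Float32)
L = reduce_sum((X - Y).^2)      # same ops as the removed loss above

run(sess, L, Dict(X => Float32[1, 2, 3], Y => Float32[1, 1, 1]))  # ≈ 5.0f0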