remove special-cased training code
splits model and exec to allow multi inputs
This commit is contained in:
parent
d4ee8a6a2a
commit
ba54552be5
@ -1,39 +1,40 @@
|
|||||||
struct Model
|
using Flux: mapt
|
||||||
model::Any
|
|
||||||
|
struct Exec
|
||||||
session::Session
|
session::Session
|
||||||
params::Dict{Flux.Param,Tensor}
|
input::Any
|
||||||
stacks::Dict
|
|
||||||
inputs::Vector{Tensor}
|
|
||||||
output::Any
|
output::Any
|
||||||
|
params::Dict{Flux.Param,Tensor}
|
||||||
|
stacks::Dict{Any,Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
function makesession(model, inputs; session = Session(Graph()))
|
function makesession(model, inputs; session = Session(Graph()))
|
||||||
params, stacks, output = tograph(model, inputs...)
|
params, stacks, output = tograph(model, inputs...)
|
||||||
run(session, global_variables_initializer())
|
run(session, global_variables_initializer())
|
||||||
Model(model, session, params, stacks, inputs, output)
|
Exec(session, inputs, output, params, stacks)
|
||||||
end
|
end
|
||||||
|
|
||||||
function makesession(model, n::Integer; session = Session(Graph()))
|
function makesession(model, n::Integer; session = Session(Graph()))
|
||||||
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
|
makesession(model, [placeholder(Float32) for _ = 1:n], session = session)
|
||||||
end
|
end
|
||||||
|
|
||||||
tf(model) = makesession(model, 1)
|
function (m::Exec)(args...)
|
||||||
|
@assert length(args) == length(m.input)
|
||||||
function storeparams!(sess, params)
|
run(m.session, m.output, Dict(zip(m.input, args)))
|
||||||
for (p, t) in params
|
|
||||||
p.x = run(sess, t)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
mutable struct Model
|
||||||
|
model::Any
|
||||||
function runmodel(m::Model, args...)
|
exec::Exec
|
||||||
@assert length(args) == length(m.inputs)
|
Model(model) = new(model)
|
||||||
run(m.session, m.output, Dict(zip(m.inputs, args)))
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(x)
|
tf(model) = Model(model)
|
||||||
@tferr m.stacks runmodel(m, convert.(Float32, x))
|
|
||||||
|
function (m::Model)(args...)
|
||||||
|
args = mapt(x->convert.(Float32, x),args)
|
||||||
|
# Build the TensorFlow Exec lazily on the first call and reuse it afterwards.
# The guard must test the `exec` field: `Model` has no `graph` field, so
# `isdefined(m, :graph)` is always false and would rebuild the session each call.
isdefined(m, :exec) || (m.exec = makesession(m.model, length(args)))
|
||||||
|
@tferr m.exec.stacks m.exec(args...)
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in :[back!, update!].args
|
for f in :[back!, update!].args
|
||||||
@ -41,27 +42,3 @@ for f in :[back!, update!].args
|
|||||||
error($(string(f)) * " is not yet supported on TensorFlow models")
|
error($(string(f)) * " is not yet supported on TensorFlow models")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
import Juno: info
|
|
||||||
|
|
||||||
function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|
||||||
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
|
||||||
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
|
||||||
i = 0
|
|
||||||
Y = placeholder(Float32)
|
|
||||||
Loss = loss(m.output, Y)
|
|
||||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
|
||||||
for e in 1:epoch
|
|
||||||
info("Epoch $e\n")
|
|
||||||
@progress for (x, y) in train
|
|
||||||
y, cur_loss, _ = run(m.session, [m.output, Loss, minimize_op],
|
|
||||||
Dict(m.inputs[1] => batchone(convertel(Float32, x)),
|
|
||||||
Y => batchone(convertel(Float32, y))))
|
|
||||||
if i % 5000 == 0
|
|
||||||
@show y
|
|
||||||
@show accuracy(m, test)
|
|
||||||
end
|
|
||||||
i += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
@ -3,12 +3,16 @@ Flux.loadtf()
|
|||||||
|
|
||||||
@testset "TensorFlow" begin
|
@testset "TensorFlow" begin
|
||||||
|
|
||||||
xs = rand(1, 20)
|
xs, ys = rand(1, 20), rand(1, 20)
|
||||||
d = Affine(20, 10)
|
d = Affine(20, 10)
|
||||||
|
|
||||||
dt = tf(d)
|
dt = tf(d)
|
||||||
@test d(xs) ≈ dt(xs)
|
@test d(xs) ≈ dt(xs)
|
||||||
|
|
||||||
|
m = Multi(20, 15)
|
||||||
|
mm = tf(m)
|
||||||
|
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||||
|
|
||||||
@testset "Tensor interface" begin
|
@testset "Tensor interface" begin
|
||||||
sess = TensorFlow.Session()
|
sess = TensorFlow.Session()
|
||||||
X = placeholder(Float32)
|
X = placeholder(Float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user