use params instead of vars

This commit is contained in:
Mike J Innes 2017-05-01 18:27:52 +01:00
parent 796d7d7e99
commit b35f50571c
3 changed files with 13 additions and 7 deletions

View File

@ -52,7 +52,10 @@ interp(ctx, c::Conv2D, x) =
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) = interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
haskey(ctx[:params], p.value) ? haskey(ctx[:params], p.value) ?
ctx[:params][p.value] : ctx[:params][p.value] :
(ctx[:params][p.value] = Variable(convertel(Float32, p.value.x))) (ctx[:params][p.value] =
ctx[:variables] ?
Variable(Float32.(p.value.x)) :
placeholder(Float32))
interp(ctx, p::Constant) = p.value interp(ctx, p::Constant) = p.value
@ -63,14 +66,15 @@ function interp(ctx, model, args...)
interpret(ctx, g, interpv(ctx, args)...) interpret(ctx, g, interpv(ctx, args)...)
end end
function tograph(model, args...) function tograph(model, args...; variables = false)
ctx = Context(mux(iline, ilambda, imap, interp), ctx = Context(mux(iline, ilambda, imap, interp),
params = ObjectIdDict(), stacks = Dict()) params = ObjectIdDict(), stacks = Dict(), variables = variables)
out = interp(ctx, model, map(constant, args)...) out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out return ctx[:params], ctx[:stacks], out
end end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3] TensorFlow.Tensor(m::Flux.Model, args...) =
tograph(m, args...; variables = true)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data)) RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -22,7 +22,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function (m::Exec)(args...) function (m::Exec)(args...)
shapecheckt(m.input, args) shapecheckt(m.input, args)
retuple(run(m.session, m.output, dictt(m.input, args))) idict = dictt(m.input, args)
pdict = Dict(t => p.x for (p, t) in m.params)
retuple(run(m.session, m.output, merge(idict, pdict)))
end end
mutable struct Model mutable struct Model
@ -34,7 +36,7 @@ end
tf(model) = Model(model) tf(model) = Model(model)
function (m::Model)(args...) function (m::Model)(args...)
args = mapt(x->convert.(Float32, x), args) args = mapt(x->Float32.(x), args)
isdefined(m, :graph) || (m.exec = makesession(m.model, args)) isdefined(m, :graph) || (m.exec = makesession(m.model, args))
@tferr m.exec.stacks m.exec(args...) @tferr m.exec.stacks m.exec(args...)
end end

View File

@ -27,7 +27,7 @@ end
Y = Tensor(d, X) Y = Tensor(d, X)
run(sess, global_variables_initializer()) run(sess, global_variables_initializer())
@test run(sess, Y, Dict(X=>Float32.(xs))) d(xs) @test run(sess, Y, Dict(X=>xs)) d(xs)
end end
@testset "Stack Traces" begin @testset "Stack Traces" begin