use params instead of vars
This commit is contained in:
parent
796d7d7e99
commit
b35f50571c
@ -52,7 +52,10 @@ interp(ctx, c::Conv2D, x) =
|
||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||
haskey(ctx[:params], p.value) ?
|
||||
ctx[:params][p.value] :
|
||||
(ctx[:params][p.value] = Variable(convertel(Float32, p.value.x)))
|
||||
(ctx[:params][p.value] =
|
||||
ctx[:variables] ?
|
||||
Variable(Float32.(p.value.x)) :
|
||||
placeholder(Float32))
|
||||
|
||||
interp(ctx, p::Constant) = p.value
|
||||
|
||||
@ -63,14 +66,15 @@ function interp(ctx, model, args...)
|
||||
interpret(ctx, g, interpv(ctx, args)...)
|
||||
end
|
||||
|
||||
function tograph(model, args...)
|
||||
function tograph(model, args...; variables = false)
|
||||
ctx = Context(mux(iline, ilambda, imap, interp),
|
||||
params = ObjectIdDict(), stacks = Dict())
|
||||
params = ObjectIdDict(), stacks = Dict(), variables = variables)
|
||||
out = interp(ctx, model, map(constant, args)...)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
||||
TensorFlow.Tensor(m::Flux.Model, args...) =
|
||||
tograph(m, args...; variables = true)[3]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
||||
|
@ -22,7 +22,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||
|
||||
function (m::Exec)(args...)
|
||||
shapecheckt(m.input, args)
|
||||
retuple(run(m.session, m.output, dictt(m.input, args)))
|
||||
idict = dictt(m.input, args)
|
||||
pdict = Dict(t => p.x for (p, t) in m.params)
|
||||
retuple(run(m.session, m.output, merge(idict, pdict)))
|
||||
end
|
||||
|
||||
mutable struct Model
|
||||
@ -34,7 +36,7 @@ end
|
||||
tf(model) = Model(model)
|
||||
|
||||
function (m::Model)(args...)
|
||||
args = mapt(x->convert.(Float32, x), args)
|
||||
args = mapt(x->Float32.(x), args)
|
||||
isdefined(m, :graph) || (m.exec = makesession(m.model, args))
|
||||
@tferr m.exec.stacks m.exec(args...)
|
||||
end
|
||||
|
@ -27,7 +27,7 @@ end
|
||||
Y = Tensor(d, X)
|
||||
run(sess, global_variables_initializer())
|
||||
|
||||
@test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs)
|
||||
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||
end
|
||||
|
||||
@testset "Stack Traces" begin
|
||||
|
Loading…
Reference in New Issue
Block a user