use params instead of vars
This commit is contained in:
parent
796d7d7e99
commit
b35f50571c
@ -52,7 +52,10 @@ interp(ctx, c::Conv2D, x) =
|
|||||||
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}}) =
|
||||||
haskey(ctx[:params], p.value) ?
|
haskey(ctx[:params], p.value) ?
|
||||||
ctx[:params][p.value] :
|
ctx[:params][p.value] :
|
||||||
(ctx[:params][p.value] = Variable(convertel(Float32, p.value.x)))
|
(ctx[:params][p.value] =
|
||||||
|
ctx[:variables] ?
|
||||||
|
Variable(Float32.(p.value.x)) :
|
||||||
|
placeholder(Float32))
|
||||||
|
|
||||||
interp(ctx, p::Constant) = p.value
|
interp(ctx, p::Constant) = p.value
|
||||||
|
|
||||||
@ -63,14 +66,15 @@ function interp(ctx, model, args...)
|
|||||||
interpret(ctx, g, interpv(ctx, args)...)
|
interpret(ctx, g, interpv(ctx, args)...)
|
||||||
end
|
end
|
||||||
|
|
||||||
function tograph(model, args...)
|
function tograph(model, args...; variables = false)
|
||||||
ctx = Context(mux(iline, ilambda, imap, interp),
|
ctx = Context(mux(iline, ilambda, imap, interp),
|
||||||
params = ObjectIdDict(), stacks = Dict())
|
params = ObjectIdDict(), stacks = Dict(), variables = variables)
|
||||||
out = interp(ctx, model, map(constant, args)...)
|
out = interp(ctx, model, map(constant, args)...)
|
||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
|
||||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
TensorFlow.Tensor(m::Flux.Model, args...) =
|
||||||
|
tograph(m, args...; variables = true)[3]
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
|
@ -22,7 +22,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
|||||||
|
|
||||||
function (m::Exec)(args...)
|
function (m::Exec)(args...)
|
||||||
shapecheckt(m.input, args)
|
shapecheckt(m.input, args)
|
||||||
retuple(run(m.session, m.output, dictt(m.input, args)))
|
idict = dictt(m.input, args)
|
||||||
|
pdict = Dict(t => p.x for (p, t) in m.params)
|
||||||
|
retuple(run(m.session, m.output, merge(idict, pdict)))
|
||||||
end
|
end
|
||||||
|
|
||||||
mutable struct Model
|
mutable struct Model
|
||||||
@ -34,7 +36,7 @@ end
|
|||||||
tf(model) = Model(model)
|
tf(model) = Model(model)
|
||||||
|
|
||||||
function (m::Model)(args...)
|
function (m::Model)(args...)
|
||||||
args = mapt(x->convert.(Float32, x), args)
|
args = mapt(x->Float32.(x), args)
|
||||||
isdefined(m, :graph) || (m.exec = makesession(m.model, args))
|
isdefined(m, :graph) || (m.exec = makesession(m.model, args))
|
||||||
@tferr m.exec.stacks m.exec(args...)
|
@tferr m.exec.stacks m.exec(args...)
|
||||||
end
|
end
|
||||||
|
@ -27,7 +27,7 @@ end
|
|||||||
Y = Tensor(d, X)
|
Y = Tensor(d, X)
|
||||||
run(sess, global_variables_initializer())
|
run(sess, global_variables_initializer())
|
||||||
|
|
||||||
@test run(sess, Y, Dict(X=>Float32.(xs))) ≈ d(xs)
|
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Stack Traces" begin
|
@testset "Stack Traces" begin
|
||||||
|
Loading…
Reference in New Issue
Block a user