This commit is contained in:
ylxdzsw 2017-08-01 13:28:14 +08:00
parent 3000c7bbcb
commit 74a4a48162
2 changed files with 10 additions and 7 deletions

View File

@ -8,15 +8,18 @@ struct Exec
stacks ::Dict{Any,Any} stacks ::Dict{Any,Any}
end end
dummy(x::Void) = TensorFlow.constant(0)
dummy(x::Tensor) = x
function makesession(model, inputs; session = Session(Graph())) function makesession(model, inputs; session = Session(Graph()))
inputs = mapt(_ -> placeholder(Float32), inputs) inputs = mapt(_ -> placeholder(Float32), inputs)
params, stacks, output = tograph(model, inputs...) params, stacks, output = tograph(model, inputs...)
output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output) output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output)
params = Dict(x=>Param{Tensor}(y, gradients(mapt(x->x.x, output), params = Dict(x=>Param{Tensor}(y, dummy(gradients(map(x->x.x, collectt(output)),
y, mapt(x->x.Δx, output))) y, map(x->x.Δx, collectt(output)))))
for (x, y) in params) for (x, y) in params)
inputs = mapt(x->Param{Tensor}(x, gradients(mapt(x->x.x, output), inputs = mapt(x->Param{Tensor}(x, dummy(gradients(map(x->x.x, collectt(output)),
x, mapt(x->x.Δx, output))), x, map(x->x.Δx, collectt(output))))),
inputs) inputs)
run(session, global_variables_initializer()) run(session, global_variables_initializer())
Exec(session, inputs, output, params, stacks) Exec(session, inputs, output, params, stacks)
@ -30,7 +33,7 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function (m::Exec)(args...) function (m::Exec)(args...)
dict = merge( dict = merge(
Dict(y.x=>x.x for (x, y) in m.params), Dict(y.x=>x.x for (x, y) in m.params),
Dict(x.x=>y for (x, y) in zip(m.input, args)) Dict(x.x=>y for (x, y) in dictt(m.input, args))
) )
retuple(run(m.session, mapt(x->x.x, m.output), dict)) retuple(run(m.session, mapt(x->x.x, m.output), dict))
end end

View File

@ -25,8 +25,8 @@ end
@testset "Ops" begin @testset "Ops" begin
A = randn(Float32,(5,5)) A = randn(Float32,(5,5))
u,s,v = tf(@net x -> svd(x))(A) # u,s,v = tf(@net x -> svd(x))(A)
@test A u*diagm(s)*transpose(v) # @test A ≈ u*diagm(s)*transpose(v)
@test tf(@net x -> inv(x))(A) inv(A) @test tf(@net x -> inv(x))(A) inv(A)
@test tf(@net x -> det(x))(A) det(A) @test tf(@net x -> det(x))(A) det(A)
A = randn(Float32,(6,3)) A = randn(Float32,(6,3))