diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index b6db200b..cdb7f5e5 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -13,9 +13,11 @@ function makesession(model, inputs; session = Session(Graph())) params, stacks, output = tograph(model, inputs...) output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output) params = Dict(x=>Param{Tensor}(y, gradients(mapt(x->x.x, output), - y, mapt(x->x.Δx, output))) for (x, y) in params) + y, mapt(x->x.Δx, output))) + for (x, y) in params) inputs = mapt(x->Param{Tensor}(x, gradients(mapt(x->x.x, output), - x, mapt(x->x.Δx, output))), inputs) + x, mapt(x->x.Δx, output))), + inputs) run(session, global_variables_initializer()) Exec(session, inputs, output, params, stacks) end