tensorflow exception handling

This commit is contained in:
Mike J Innes 2017-03-12 18:34:11 +00:00
parent 4b82c57f88
commit daf962a831
3 changed files with 47 additions and 30 deletions

View File

@ -2,7 +2,7 @@ using Base: @get!
using DataFlow: Constant, constant, Split
using DataFlow.Interpreter
using Flux: imap
using TensorFlow: RawTensor
using TensorFlow: RawTensor, TFException
# TODO: implement Julia's type promotion rules
@ -70,3 +70,29 @@ end
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
# Error Handling
using Juno
using MacroTools: @q
using DataFlow.Interpreter: Exception, totrace
Juno.errmsg(e::TFException) = string(e.status)
function errnode(e::TFException)
m = match(r"Node: ([\w\d]+) =", string(e.status))
m == nothing && return
m.captures[1]
end
errnode(e) = nothing
macro tferr(stk, ex)
@q try
$(esc(ex))
catch e
(node = errnode(e)) != nothing || rethrow()
stk = $(esc(stk))
haskey(stk, node) || rethrow()
throw(Exception(e, totrace(stk[node])))
end
end

View File

@ -27,40 +27,19 @@ end
storeparams!(m::Model) = storeparams!(m.session, m.params)
ismultioutput(m::Model) = !isa(m.output, Tensor)
function tferr(model::Model, e)
m = match(r"Node: ([\w\d]+) =", string(e.status))
m == nothing && return
node = m.captures[1]
if haskey(model.stacks, node)
stk = model.stacks[node]
println("TensorFlow error occured at:")
foreach(l -> println("$(l.file):$(l.line)"), stk)
end
end
function runmodel(m::Model, args...)
@assert length(args) == length(m.inputs)
try
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
catch e
isa(e, TensorFlow.TFException) || rethrow(e)
tferr(m, e)
rethrow(e)
run(m.session, m.output, Dict(zip(m.inputs, args)))
end
using Flux: runrawbatched
function (m::Model)(x)
@tferr m.stacks runrawbatched(convertel(Float32, x)) do x
output = runmodel(m, x)
end
end
function (m::Model)(args::Batch...)
runmodel(m, map(x -> convertel(Float32, x), args)...)
end
function (m::Model)(args...)
output = m(map(batchone, args)...)
ismultioutput(m) ? map(first, output) : first(output)
end
for f in :[back!, update!].args
@eval function Flux.$f(m::Model, args...)
error($(string(f)) * " is not yet supported on TensorFlow models")

View File

@ -18,4 +18,16 @@ dt = tf(d)
@test run(sess, Y, Dict(X=>xs')) d(xs)'
end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
dm = tf(model)
e = try dm(rand(10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)
@test e.trace[1].func == Symbol("Flux.Affine")
@test e.trace[2].func == :TLP
end
end