tensorflow exception handling
This commit is contained in:
parent
4b82c57f88
commit
daf962a831
|
@ -2,7 +2,7 @@ using Base: @get!
|
|||
using DataFlow: Constant, constant, Split
|
||||
using DataFlow.Interpreter
|
||||
using Flux: imap
|
||||
using TensorFlow: RawTensor
|
||||
using TensorFlow: RawTensor, TFException
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
|
@ -70,3 +70,29 @@ end
|
|||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
||||
# Error Handling
|
||||
|
||||
using Juno
|
||||
using MacroTools: @q
|
||||
using DataFlow.Interpreter: Exception, totrace
|
||||
Juno.errmsg(e::TFException) = string(e.status)
|
||||
|
||||
function errnode(e::TFException)
|
||||
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
||||
m == nothing && return
|
||||
m.captures[1]
|
||||
end
|
||||
|
||||
errnode(e) = nothing
|
||||
|
||||
macro tferr(stk, ex)
|
||||
@q try
|
||||
$(esc(ex))
|
||||
catch e
|
||||
(node = errnode(e)) != nothing || rethrow()
|
||||
stk = $(esc(stk))
|
||||
haskey(stk, node) || rethrow()
|
||||
throw(Exception(e, totrace(stk[node])))
|
||||
end
|
||||
end
|
||||
|
|
|
@ -27,40 +27,19 @@ end
|
|||
|
||||
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
||||
|
||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
||||
|
||||
function tferr(model::Model, e)
|
||||
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
||||
m == nothing && return
|
||||
node = m.captures[1]
|
||||
if haskey(model.stacks, node)
|
||||
stk = model.stacks[node]
|
||||
println("TensorFlow error occured at:")
|
||||
foreach(l -> println("$(l.file):$(l.line)"), stk)
|
||||
end
|
||||
end
|
||||
|
||||
function runmodel(m::Model, args...)
|
||||
@assert length(args) == length(m.inputs)
|
||||
try
|
||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
|
||||
catch e
|
||||
isa(e, TensorFlow.TFException) || rethrow(e)
|
||||
tferr(m, e)
|
||||
rethrow(e)
|
||||
run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||
end
|
||||
|
||||
using Flux: runrawbatched
|
||||
|
||||
function (m::Model)(x)
|
||||
@tferr m.stacks runrawbatched(convertel(Float32, x)) do x
|
||||
output = runmodel(m, x)
|
||||
end
|
||||
end
|
||||
|
||||
function (m::Model)(args::Batch...)
|
||||
runmodel(m, map(x -> convertel(Float32, x), args)...)
|
||||
end
|
||||
|
||||
function (m::Model)(args...)
|
||||
output = m(map(batchone, args)...)
|
||||
ismultioutput(m) ? map(first, output) : first(output)
|
||||
end
|
||||
|
||||
for f in :[back!, update!].args
|
||||
@eval function Flux.$f(m::Model, args...)
|
||||
error($(string(f)) * " is not yet supported on TensorFlow models")
|
||||
|
|
|
@ -18,4 +18,16 @@ dt = tf(d)
|
|||
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
||||
end
|
||||
|
||||
@testset "Stack Traces" begin
|
||||
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||
info("The following warning is normal")
|
||||
dm = tf(model)
|
||||
e = try dm(rand(10))
|
||||
catch e e end
|
||||
|
||||
@test isa(e, DataFlow.Interpreter.Exception)
|
||||
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||
@test e.trace[2].func == :TLP
|
||||
end
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue