tensorflow exception handling
This commit is contained in:
parent
4b82c57f88
commit
daf962a831
@ -2,7 +2,7 @@ using Base: @get!
|
|||||||
using DataFlow: Constant, constant, Split
|
using DataFlow: Constant, constant, Split
|
||||||
using DataFlow.Interpreter
|
using DataFlow.Interpreter
|
||||||
using Flux: imap
|
using Flux: imap
|
||||||
using TensorFlow: RawTensor
|
using TensorFlow: RawTensor, TFException
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
|
|
||||||
@ -70,3 +70,29 @@ end
|
|||||||
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
TensorFlow.Tensor(m::Flux.Model, args...) = tograph(m, args...)[3]
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
|
# Error Handling
|
||||||
|
|
||||||
|
using Juno
|
||||||
|
using MacroTools: @q
|
||||||
|
using DataFlow.Interpreter: Exception, totrace
|
||||||
|
Juno.errmsg(e::TFException) = string(e.status)
|
||||||
|
|
||||||
|
function errnode(e::TFException)
|
||||||
|
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
||||||
|
m == nothing && return
|
||||||
|
m.captures[1]
|
||||||
|
end
|
||||||
|
|
||||||
|
errnode(e) = nothing
|
||||||
|
|
||||||
|
macro tferr(stk, ex)
|
||||||
|
@q try
|
||||||
|
$(esc(ex))
|
||||||
|
catch e
|
||||||
|
(node = errnode(e)) != nothing || rethrow()
|
||||||
|
stk = $(esc(stk))
|
||||||
|
haskey(stk, node) || rethrow()
|
||||||
|
throw(Exception(e, totrace(stk[node])))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
@ -27,38 +27,17 @@ end
|
|||||||
|
|
||||||
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
storeparams!(m::Model) = storeparams!(m.session, m.params)
|
||||||
|
|
||||||
ismultioutput(m::Model) = !isa(m.output, Tensor)
|
|
||||||
|
|
||||||
function tferr(model::Model, e)
|
|
||||||
m = match(r"Node: ([\w\d]+) =", string(e.status))
|
|
||||||
m == nothing && return
|
|
||||||
node = m.captures[1]
|
|
||||||
if haskey(model.stacks, node)
|
|
||||||
stk = model.stacks[node]
|
|
||||||
println("TensorFlow error occured at:")
|
|
||||||
foreach(l -> println("$(l.file):$(l.line)"), stk)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function runmodel(m::Model, args...)
|
function runmodel(m::Model, args...)
|
||||||
@assert length(args) == length(m.inputs)
|
@assert length(args) == length(m.inputs)
|
||||||
try
|
run(m.session, m.output, Dict(zip(m.inputs, args)))
|
||||||
output = run(m.session, m.output, Dict(zip(m.inputs, args)))
|
|
||||||
ismultioutput(m) ? (rebatch.(output)...,) : rebatch(output)
|
|
||||||
catch e
|
|
||||||
isa(e, TensorFlow.TFException) || rethrow(e)
|
|
||||||
tferr(m, e)
|
|
||||||
rethrow(e)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::Model)(args::Batch...)
|
using Flux: runrawbatched
|
||||||
runmodel(m, map(x -> convertel(Float32, x), args)...)
|
|
||||||
end
|
|
||||||
|
|
||||||
function (m::Model)(args...)
|
function (m::Model)(x)
|
||||||
output = m(map(batchone, args)...)
|
@tferr m.stacks runrawbatched(convertel(Float32, x)) do x
|
||||||
ismultioutput(m) ? map(first, output) : first(output)
|
output = runmodel(m, x)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in :[back!, update!].args
|
for f in :[back!, update!].args
|
||||||
|
@ -18,4 +18,16 @@ dt = tf(d)
|
|||||||
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
@test run(sess, Y, Dict(X=>xs')) ≈ d(xs)'
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "Stack Traces" begin
|
||||||
|
model = TLP(Affine(10, 20), Affine(21, 15))
|
||||||
|
info("The following warning is normal")
|
||||||
|
dm = tf(model)
|
||||||
|
e = try dm(rand(10))
|
||||||
|
catch e e end
|
||||||
|
|
||||||
|
@test isa(e, DataFlow.Interpreter.Exception)
|
||||||
|
@test e.trace[1].func == Symbol("Flux.Affine")
|
||||||
|
@test e.trace[2].func == :TLP
|
||||||
|
end
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user