# Flux.jl/src/backend/tensorflow/graph.jl
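#
# Lowers Flux models to TensorFlow graphs by walking their DataFlow
# representation: `node` converts Julia values into TF nodes, `graph` maps
# individual operations onto TF ops, and `tograph`/`astensor` run the
# interpreter over a whole model.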

using Base: @get!
using Flux: Reshape, MaxPool, flatten
using DataFlow: constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
using TensorFlow: RawTensor, TFException

# TODO: implement Julia's type promotion rules
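# Lift Julia values into TensorFlow graph nodes: tuples recurse, tensors and
# variables pass through unchanged, and numbers become Float32 constants.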
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))

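# `graph` lowers a single Flux/Julia operation onto the corresponding
# TensorFlow op; the methods below cover tuples and indexing, activations,
# reductions, and linear algebra.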
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(identity), x) = TensorFlow.identity(x)

graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(hcat), xs...) = concat(1, xs)

# Reductions: `dim = nothing` reduces over every dimension.
graph(::typeof(sum), x, dim = nothing) = TensorFlow.reduce_sum(x; axis = dim)
graph(::typeof(prod), x, dim = nothing) = TensorFlow.reduce_prod(x; axis = dim)
graph(::typeof(min), x, dim = nothing) = TensorFlow.reduce_min(x; axis = dim)
graph(::typeof(max), x, dim = nothing) = TensorFlow.reduce_max(x; axis = dim)
graph(::typeof(all), x, dim = nothing) = TensorFlow.reduce_all(x; axis = dim)
graph(::typeof(any), x, dim = nothing) = TensorFlow.reduce_any(x; axis = dim)
graph(::typeof(mean), x, dim = nothing) = TensorFlow.reduce_mean(x; axis = dim)

graph(::typeof(svd), x) = svd(x)

graph(::typeof(size), x, dim) = TensorFlow.size(x, convert(Tensor{Int32}, dim))
graph(::typeof(size), x) = TensorFlow.size(x)

# TensorFlow's Cholesky returns the lower-triangular factor; transpose it to
# match Julia's upper-triangular `chol`.
graph(::typeof(chol), args...) = TensorFlow.transpose(TensorFlow.cholesky(args...))

graph(::typeof(reshape), x, dims) = TensorFlow.reshape(x, convert(Tensor{Int32}, dims))

for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos,
           sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj,
           inv, det, transpose, permutedims, cat, length, diag, diagm)
  @eval graph(::typeof($op), args...) = $op(args...)
end
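# The ops above rely on TensorFlow.jl's own overloads: each has a method on
# `Tensor`, so calling the Julia function on tensor arguments records the
# matching TF op in the graph.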

for op in (+, -, *, /)
  @eval graph(::typeof(broadcast), ::typeof($op), args...) = broadcast($op, args...)
end

graph(::typeof(.-), args...) = -(args...)

graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)

# Reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79:
# compute the batch dimension at runtime and let TF infer the rest via -1.
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))

graph(::Input, x) = x
graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(op::Op, xs...) = op.f(xs...)
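
# Interpreter hooks. `graph(ctx, ...)` also records, for every Tensor it
# creates, the Julia call stack that produced it; `@tferr` below uses this
# to turn TensorFlow runtime errors back into Julia stack traces.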
function graph(ctx::Context, model, args...)
  node = graph(model, args...)
  node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
  return node
end

interp(ctx, c::Conv2D, x) =
  nn.conv2d(x, interp(ctx, constant(c.filter)), [1, c.stride..., 1], "VALID")

# Look up (or create) the TF node for a Flux parameter. With `variables = true`
# weights become TF `Variable`s initialised from the Julia array; otherwise
# they become placeholders to be fed at run time.
param(ctx, p::Flux.Param{<:AbstractArray}) =
  haskey(ctx[:params], p) ?
    ctx[:params][p] :
    (ctx[:params][p] =
       ctx[:variables] ?
         Variable(Float32.(p.x)) :
         placeholder(Float32))

param(ctx, x) = x

function interp(ctx, model, args...)
  args = param.(ctx, args)
  g = Flux.graph(model)
  g == nothing && return graph(ctx, model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, args...)
end

function tograph(model, args...; variables = false)
  ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
                params = ObjectIdDict(), stacks = Dict(), variables = variables)
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], ctx[:stacks], out
end
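
# Minimal usage sketch (hypothetical layer and shapes; `Affine` is assumed to
# be available in this Flux version):
#
#   m = Flux.Affine(10, 5)
#   params, stacks, out = tograph(m, rand(1, 10); variables = true)
#   # `out` is a TF Tensor, `params` maps each Flux param to its TF Variable,
#   # and `stacks` maps TF node names to Julia stacks for error reporting.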

astensor(model, args...) =
  tograph(model, args...; variables = true)[3]

# Feed Flux batches/sequences to TensorFlow as their underlying raw arrays.
RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))

# Error Handling

using Juno
using MacroTools: @q
using DataFlow.Interpreter: Exception, totrace

Juno.errmsg(e::TFException) = string(e.status)

# Extract the name of the failing node from a TensorFlow error message.
function errnode(e::TFException)
  m = match(r"Node: ([\w\d]+) =", string(e.status))
  m == nothing && return
  m.captures[1]
end
errnode(e) = nothing

macro tferr(stk, ex)
  @q try
    $(esc(ex))
  catch e
    (node = errnode(e)) != nothing || rethrow()
    stk = $(esc(stk))
    haskey(stk, node) || rethrow()
    throw(Exception(e, totrace(stk[node])))
  end
end
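
# Usage sketch (names assumed): wrap a session call so a failing TF node is
# mapped back to the Julia stack recorded when its graph node was built:
#
#   _, stacks, out = tograph(model, x)
#   @tferr stacks run(session, out)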