remove backends
parent 4604e9f515
commit 536949891d
src/Flux.jl
@@ -38,8 +38,6 @@ include("layers/cost.jl")
 include("layers/recurrent.jl")
 include("layers/shims.jl")
 
-include("backend/backend.jl")
-
 include("data.jl")
 include("training.jl")
 
src/backend/backend.jl (deleted)
@@ -1,28 +0,0 @@
# We use a lazy-loading trick to load the backend code as needed; this avoids
# the need for a hard dependency on both backends.

# This is effectively equivalent to:
# include("tensorflow/tensorflow.jl")
# using .TF
# export tf
# but instead of loading immediately, we wait until `tf` is first called.

function loadtf()
  isdefined(Flux, :TF) && return
  @eval include(joinpath(dirname($@__FILE__), "tensorflow/tensorflow.jl"))
end

function tf(args...)
  loadtf()
  eval(:(TF.tf($(QuoteNode.(args)...))))
end

function loadmx()
  isdefined(Flux, :MX) && return
  @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
end

function mxnet(args...)
  loadmx()
  eval(:(MX.mxnet($(QuoteNode.(args)...))))
end
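For context, the removed `tf` and `mxnet` entry points were used exactly as the backend tests further down in this diff show: the first call loads the corresponding backend module, and later calls reuse it. A minimal usage sketch, assuming the pre-removal Flux of this era with MXNet.jl and TensorFlow.jl installed (it will not run against later Flux versions):

d = Affine(20, 10)   # an ordinary Flux model
dm = mxnet(d)        # first call runs loadmx() and includes mxnet/mxnet.jl
dm(rand(1, 20))      # forward pass through the MXNet executor
dt = tf(d)           # likewise loads the TensorFlow backend on first use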
src/backend/mxnet/graph.jl (deleted)
@@ -1,142 +0,0 @@
function nodename(s::mx.SymbolicNode)
  name = Ref{mx.char_p}(0)
  success = Ref(0)
  mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}), s.handle.value, name, success)
  @assert success[] != -1
  return Symbol(unsafe_string(name[]))
end

using Base: @get!
using DataFlow: Constant, constant
using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace
import Flux: Reshape, MaxPool, flatten, mapt, broadcastto, ∘

# TODO: implement Julia's type promotion rules

node(x::Tuple) = map(node, x)
node(x::mx.SymbolicNode) = x

graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(identity), x) = x
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
graph(::typeof(flatten), x) = mx.Flatten(x)
graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1)
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))

graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...)
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
# Old broadcasters
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)

graph(::typeof(softmax), xs) =
  mx.broadcast_div(exp(xs), mx.sum(exp(xs), axis = 1, keepdims=true))

graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
graph(::typeof(vcat), a...) = graph(cat, 1, a...)

graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
graph(::typeof(Base.Iterators.repeated), x, n) = ntuple(_ -> x, n)

a::mx.SymbolicNode ∘ b::mx.SymbolicNode = mx.broadcast_mul(a, b)

graph(::Input, x) = x

struct AlterParam
  param
  load
  store
end

Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))

graph(ctx::Context, d::Affine, x) =
  !ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
  register(ctx,
    mx.FullyConnected(mx.SymbolicNode, data = x,
                      num_hidden = size(d.W.x, 2),
                      weight = var(ctx, AlterParam(d.W, x->x', nothing)),
                      bias = var(ctx, AlterParam(d.b, x->squeeze(x, 1), nothing))))

# TODO: use actual params
graph(ctx::Context, c::Conv2D, x) =
  mx.Convolution(x,
                 kernel = size(c.filter, 1, 2),
                 num_filter = size(c.filter, 4),
                 stride = c.stride)

graph(ctx::Context, p::MaxPool, x) =
  mx.Pooling(x,
             pool_type = :max,
             kernel = p.size,
             stride = p.stride)

function register(ctx::Context, node::mx.SymbolicNode)
  ctx[:stacks][nodename(node)] = stack(ctx)
  return node
end

register(ctx::Context, node) = node

function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam})
  haskey(ctx[:params], p) && return ctx[:params][p]
  ctx[:params][p] = mx.Variable(gensym())
end

var(ctx::Context, x) = x

function graph(ctx::Context, model, args...)
  args = var.(ctx, args)
  g = Flux.graph(model)
  g == nothing && return register(ctx, @icatch ctx graph(model, args...))
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, args...)
end

graph′(ctx::Context, args...) = @icatch ctx graph(ctx, args...)

function tograph(model, args...; feedforward = false)
  ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph′),
                params = ObjectIdDict(), stacks = Dict(),
                feedforward = feedforward)
  out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
  params = Dict(nodename(v) => p for (p, v) in ctx[:params])
  return Graph(args, out, params, ctx[:stacks])
end

# Error Handling

using Juno
using MacroTools: @q
Juno.errmsg(e::mx.MXError) = e.msg

function errnode(e::mx.MXError)
  m = match(r"Error in operator (\w+)", e.msg)
  m == nothing && return
  Symbol(m.captures[1])
end

striptrace(e::mx.MXError) = mx.MXError(split(e.msg, "\n")[1])

macro mxerr(stk, ex)
  @q try
    $(esc(ex))
  catch e
    (e isa mx.MXError && (node = errnode(e)) != nothing) || rethrow()
    stk = $(esc(stk))
    haskey(stk, node) || rethrow()
    throw(Exception(striptrace(e), totrace(stk[node])))
  end
end
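The `graph(::typeof(f), args...)` methods above form a translation table from Julia functions to MXNet symbolic operators by dispatching on each function's singleton type. The same pattern in isolation, with purely illustrative names and no MXNet dependency:

# Map a handful of Julia functions to backend op names via dispatch on typeof(f).
translate(::typeof(+), args...)    = ("add", args...)
translate(::typeof(*), args...)    = ("matmul", args...)
translate(::typeof(tanh), args...) = ("tanh", args...)
translate(f, args...) = error("no translation for $f")

translate(+, 1, 2)    # ("add", 1, 2)
translate(tanh, 0.5)  # ("tanh", 0.5)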
src/backend/mxnet/model.jl (deleted)
@@ -1,159 +0,0 @@
using Flux: collectt, shapecheckt, back!, update!

function copyargs!(as, bs)
  for id in intersect(keys(as), keys(bs))
    copy!(as[id], bs[id])
  end
end

struct Graph
  input
  output
  params::Dict{Symbol,Any}
  stacks::Dict{Any,Any}
end

function mxparams(ps, ctx)
  params = Dict{Symbol,MXArray}()
  for (name, param) in ps
    params[name] = MXArray(size(param), ctx)
  end
  return params
end

ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)

struct Exec
  graph::Graph
  ctx::mx.Context
  exec::mx.Executor
  args::Dict{Symbol,MXArray}
  grads::Dict{Symbol,MXArray}
  outs::Vector{MXArray}
end

loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)

mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)

dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))

function executor(graph::Graph, input...; ctx = mx.cpu())
  shapecheckt(graph.input, input)
  args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
  grads = filter((a, b) -> b isa Flux.Param, graph.params)
  grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
  exec = mx.bind(mxgroup(graph.output),
                 context = ctx,
                 args = ndparams(args),
                 args_grad = ndparams(grads),
                 grad_req = mx.GRAD_ADD)
  exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
  loadparams!(exec)
  return exec
end

function (exec::Exec)(input...)
  foreach(kv -> copy!(exec.args[kv[1]], kv[2]), dictt(exec.graph.input, input))
  mx.forward(exec.exec, is_train = true)
  mxungroup(exec.graph.output, copy(exec.outs))
end

function Flux.back!(exec::Exec, Δ)
  mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
  mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
  mapt(k -> copy(exec.grads[k]), exec.graph.input)
end

function Flux.update!(exec::Exec, η)
  for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
    grad == nothing && continue
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0
    end
  end
  storeparams!(exec)
  return exec
end

toctx(ctx::mx.Context) = ctx
toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()

# TODO: if `last` changes, update params appropriately

mutable struct Model
  model::Any
  ctx::mx.Context
  execs::Dict{Tuple,Exec}
  graph::Graph
  last::Exec
  Model(model, ctx) = new(model, ctx, Dict())
end

mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))

function Base.show(io::IO, m::Model)
  print(io, "MX.Model(")
  show(io, m.model)
  print(io, ", ")
  show(io, m.ctx)
  print(io, ")")
end

import Base: @get!

# TODO: dims having its own type would be useful
executor(m::Model, input...) =
  @get!(m.execs, mapt(size, input),
        executor(m.graph, input...; ctx = m.ctx))

function (m::Model)(xs...)
  @mxerr m.graph.stacks begin
    !isdefined(m, :graph) &&
      (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
    m.last = exec = executor(m, xs...)
    exec(xs...)
  end
end

function Flux.back!(m::Model, Δ, xs...)
  m.last = exec = m.execs[mapt(size, xs)]
  back!(exec, Δ)
end

Flux.update!(m::Model, η) = (update!(m.last, η); m)

# Recurrent Models

using Flux: Stateful, SeqModel

mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate)
mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps)

# MX FeedForward interface

struct SoftmaxOutput
  name::Symbol
end

graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(xs, name = s.name)

function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  (g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end

function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
  model = rewrite_softmax(model, label)
  graph = tograph(model, input, feedforward=true)
  ff = mx.FeedForward(graph.output, context = ctx)
  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
  return ff
end
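`Model` above compiles one `Exec` per input shape and memoises it in `m.execs`, keyed by the tuple of input sizes (`@get!(m.execs, mapt(size, input), ...)`). A standalone sketch of that shape-keyed caching idea in plain, current Julia; `compile` is a stand-in for the expensive `executor` build step, not part of the removed code:

struct ShapeCache
  compiled::Dict{Tuple,Function}
end
ShapeCache() = ShapeCache(Dict{Tuple,Function}())

compile(sz::Tuple) = x -> 2 .* x   # placeholder for a per-shape compilation step

function run_cached(c::ShapeCache, x)
  f = get!(() -> compile(size(x)), c.compiled, size(x))  # build once per shape
  f(x)
end

c = ShapeCache()
run_cached(c, rand(3, 4))  # compiles for (3, 4)
run_cached(c, rand(3, 4))  # cache hit, no recompilation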
src/backend/mxnet/mxarray.jl (deleted)
@@ -1,40 +0,0 @@
using MXNet

# NDArray is row-major so by default all dimensions are reversed in MXNet.
# MXArray tranposes when loading/storing to fix this.

reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)

struct MXArray{N}
  data::mx.NDArray
  scratch::Array{Float32,N}
end

MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))

# TODO: split cpu/gpu mxarrays
MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))

Base.size(xs::MXArray) = reverse(size(xs.data))

function Base.copy!(mx::MXArray, xs::AbstractArray)
  @assert size(mx) == size(xs)
  reversedims!(mx.scratch, xs)
  copy!(mx.data, mx.scratch)
  return mx
end

function Base.copy!(xs::AbstractArray, mx::MXArray)
  @assert size(xs) == size(mx)
  copy!(mx.scratch, mx.data)
  reversedims!(xs, mx.scratch)
end

Base.copy(mx::MXArray) = copy!(Array{Float32}(size(mx)), mx)

function MXArray(xs::AbstractArray, ctx = mx.cpu())
  mx = MXArray(size(xs), ctx)
  copy!(mx, xs)
end

Base.setindex!(xs::MXArray, x::Real, ::Colon) = xs.data[:] = x
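The dimension reversal above exists because `mx.NDArray` is row-major while Julia arrays are column-major, so an array's dims are stored reversed and permuted back on copy. The same trick in isolation, written in current Julia syntax (the removed code targets Julia 0.6):

reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)

A = reshape(collect(1:6), 2, 3)   # 2×3 column-major array
B = Array{Int}(undef, 3, 2)       # storage with the dims reversed
reversedims!(B, A)
@assert size(B) == reverse(size(A))
@assert B[2, 1] == A[1, 2]        # indices are reversed as well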
src/backend/mxnet/mxnet.jl (deleted)
@@ -1,11 +0,0 @@
module MX

using MXNet, DataFlow, ..Flux

export mxnet

include("mxarray.jl")
include("graph.jl")
include("model.jl")

end
src/backend/tensorflow/graph.jl (deleted)
@@ -1,133 +0,0 @@
using Base: @get!
using Flux: Reshape, MaxPool, flatten
using DataFlow: constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
using TensorFlow: RawTensor, TFException

# TODO: implement Julia's type promotion rules

node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))

graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(identity), x) = TensorFlow.identity(x)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(hcat), xs...) = concat(1, xs)
graph(::typeof(sum), x, dim=nothing) = TensorFlow.reduce_sum(x;axis=dim)
graph(::typeof(prod), x, dim=nothing) = TensorFlow.reduce_prod(x;axis=dim)
graph(::typeof(min), x, dim=nothing) = TensorFlow.reduce_min(x;axis=dim)
graph(::typeof(max), x, dim=nothing) = TensorFlow.reduce_max(x;axis=dim)
graph(::typeof(all), x, dim=nothing) = TensorFlow.reduce_all(x;axis=dim)
graph(::typeof(any), x, dim=nothing) = TensorFlow.reduce_any(x;axis=dim)
graph(::typeof(mean), x, dim=nothing) = TensorFlow.reduce_mean(x;axis=dim)
graph(::typeof(svd), x) = svd(x)
graph(::typeof(size), x, dim) = TensorFlow.size(x,convert(Tensor{Int32}, dim))
graph(::typeof(size), x) = TensorFlow.size(x)
graph(::typeof(chol), args...) = TensorFlow.transpose(TensorFlow.cholesky(args...))
graph(::typeof(reshape), x, dims) = TensorFlow.reshape(x,convert(Tensor{Int32},dims))
graph(::typeof(Flux.tile), args...) = TensorFlow.tile(args...)
graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x))
graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...)
graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b)
graph(::typeof(triangular_solve), A, b) = TensorFlow.matrix_triangular_solve(A, b; lower=false)
graph(::typeof(randu), x) = Ops.random_uniform(convert(Tensor{Int32},x);dtype=Float32)
graph(::typeof(randn), x) = TensorFlow.random_normal(convert(Tensor{Int32},x);dtype=Float32)
graph(::typeof(Flux.expand_dims), x, dim) = TensorFlow.expand_dims(x,convert(Tensor{Int32},dim))

for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos,
           sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj,
           inv, det, transpose, permutedims, cat, length, diag, diagm)
  @eval graph(::typeof($op), args...) = $op(args...)
end

for op in (+, -, *, /)
  @eval graph(::typeof(broadcast), ::typeof($op), args...) = broadcast($op, args...)
end

graph(::typeof(.-), args...) = -(args...)

graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)

# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))

graph(::Input, x) = x

graph(p::MaxPool, x) =
  nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")

graph(op::Op, xs...) = op.f(xs...)

function graph(ctx::Context, model, args...)
  node = graph(model, args...)
  node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
  return node
end

interp(ctx, c::Conv2D, x) =
  nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID")

param(ctx, p::Flux.Param{<:AbstractArray}) =
  haskey(ctx[:params], p) ?
    ctx[:params][p] :
    (ctx[:params][p] =
      ctx[:variables] ?
        Variable(Float32.(p.x)) :
        placeholder(Float32))

param(ctx, x) = x

function interp(ctx, model, args...)
  args = param.(ctx, args)
  g = Flux.graph(model)
  g == nothing && return graph(ctx, model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, args...)
end

function tograph(model, args...; variables = false)
  ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
                params = ObjectIdDict(), stacks = Dict(), variables = variables)
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], ctx[:stacks], out
end

astensor(model, args...) =
  tograph(model, args...; variables = true)[3]

RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))

# Error Handling

using Juno
using MacroTools: @q
using DataFlow.Interpreter: Exception, totrace
Juno.errmsg(e::TFException) = string(e.status)

function errnode(e::TFException)
  m = match(r"Node: ([\w\d]+) =", string(e.status))
  m == nothing && return
  m.captures[1]
end

errnode(e) = nothing

macro tferr(stk, ex)
  @q try
    $(esc(ex))
  catch e
    (node = errnode(e)) != nothing || rethrow()
    stk = $(esc(stk))
    haskey(stk, node) || rethrow()
    throw(Exception(e, totrace(stk[node])))
  end
end
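The `for op in (...)` loops above use `@eval` to generate one forwarding `graph` method per operator rather than writing each by hand. The same metaprogramming pattern in isolation, with illustrative names:

describe(x) = "unknown"

for (f, name) in ((sin, "sine"), (cos, "cosine"), (exp, "exponential"))
  @eval describe(::typeof($f)) = $name   # one method per function, generated
end

describe(sin)  # "sine"
describe(cos)  # "cosine"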
src/backend/tensorflow/model.jl (deleted)
@@ -1,86 +0,0 @@
using Flux: Param, mapt, collectt, shapecheckt

struct Exec
  session ::Session
  input   ::Any
  output  ::Any
  params  ::Dict{Param,Param{Tensor}}
  stacks  ::Dict{Any,Any}
end

dummy(x::Void) = TensorFlow.constant(0)
dummy(x::Tensor) = x

function makesession(model, inputs; session = Session(Graph()))
  inputs = mapt(_ -> placeholder(Float32), inputs)
  params, stacks, output = tograph(model, inputs...)
  output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output)
  params = Dict(x=>Param{Tensor}(y, dummy(gradients(map(x->x.x, collectt(output)),
                                                    y, map(x->x.Δx, collectt(output)))))
                for (x, y) in params)
  inputs = mapt(x->Param{Tensor}(x, dummy(gradients(map(x->x.x, collectt(output)),
                                                    x, map(x->x.Δx, collectt(output))))),
                inputs)
  run(session, global_variables_initializer())
  Exec(session, inputs, output, params, stacks)
end

retuple(xs) = xs
retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)

dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))

function (m::Exec)(args...)
  dict = merge(
    Dict(y.x=>x.x for (x, y) in m.params),
    Dict(x.x=>y for (x, y) in dictt(m.input, args))
  )
  retuple(run(m.session, mapt(x->x.x, m.output), dict))
end

function Flux.back!(m::Exec, Δ, args...)
  dict = merge(
    Dict(y.x=>x.x for (x, y) in m.params),
    Dict(x.x=>y for (x, y) in zip(m.input, args)),
    Dict(x.Δx=>y for (x, y) in zip(collectt(m.output), collectt(Δ)))
  )

  Δin, Δps = run(m.session, (mapt(x->x.Δx, m.input), map(x->x.Δx, values(m.params))), dict)

  for (p, Δ) in zip(keys(m.params), Δps)
    p.Δx .+= Δ
  end

  Δin
end

function Flux.update!(m::Exec, η)
  for p in keys(m.params)
    Flux.update!(p, η)
  end
  return m
end

mutable struct Model
  model::Any
  exec::Exec
  Model(model) = new(model)
end

tf(model) = Model(model)

function (m::Model)(args...)
  args = mapt(x->Float32.(x), args)
  isdefined(m, :exec) || (m.exec = makesession(m.model, args))
  @tferr m.exec.stacks m.exec(args...)
end

Flux.back!(m::Model, Δ, args...) = Flux.back!(m.exec, Δ, args...)
Flux.update!(m::Model, η) = (Flux.update!(m.exec, η); m)

# Recurrent Models

using Flux: Stateful, SeqModel

tf(m::Stateful) = Stateful(tf(m.model), m.states, m.istate, m.ostate)
tf(m::SeqModel) = SeqModel(tf(m.model), m.steps)
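`mapt` (used throughout both backends) maps a function over arbitrarily nested tuples while preserving their structure. A minimal standalone equivalent, inferred from its usage here rather than taken from the actual Flux implementation:

mapt(f, x::Tuple) = map(x -> mapt(f, x), x)  # recurse into tuples
mapt(f, x) = f(x)                            # apply at the leaves

mapt(size, (rand(2, 3), (rand(4), rand(1, 5))))
# ((2, 3), ((4,), (1, 5)))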
src/backend/tensorflow/tensorflow.jl (deleted)
@@ -1,20 +0,0 @@
module TF

using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, convertel

export tf

struct Op
  f
  shape
end

Op(f) = Op(f, (d...) -> nothing)

Flux.shape(op::Op, d...) = op.shape(d...)

include("graph.jl")
include("model.jl")

end
test/backend/mxnet.jl (deleted)
@@ -1,39 +0,0 @@
using MXNet
Flux.loadmx()

@testset "MXNet" begin

xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)

dm = mxnet(d)
@test d(xs) ≈ dm(xs)

test_tupleio(mxnet)
test_recurrence(mxnet)
test_stacktrace(mxnet)
test_back(mxnet)
test_anon(mxnet)

using Flux: MaxPool

@testset "Native interface" begin
  f = Flux.MX.FeedForward(Chain(d, softmax))
  @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]

  m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
            flatten, Affine(1587, 10), softmax)
  f = Flux.MX.FeedForward(m)
  # TODO: test run
  @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end

@testset "Duplicate parameters" begin
  a = Affine(10, 10)
  d = Chain(a, a)
  m = mxnet(d)
  m(randn(1, 10))
  @test length(m.graph.params) == 2
end

end
test/backend/tensorflow.jl (deleted)
@@ -1,70 +0,0 @@
using TensorFlow
Flux.loadtf()

@testset "TensorFlow" begin

xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)

dt = tf(d)
@test d(xs) ≈ dt(xs)

test_tupleio(tf)
test_recurrence(tf)
test_stacktrace(tf)
test_anon(tf)

@testset "Tensor interface" begin
  sess = TensorFlow.Session()
  X = placeholder(Float32)
  Y = Flux.TF.astensor(d, X)
  run(sess, global_variables_initializer())

  @test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
end

@testset "Ops" begin
  A = randn(Float32,(5,5))
  # u,s,v = tf(@net x -> svd(x))(A)
  # @test A ≈ u*diagm(s)*transpose(v)
  @test tf(@net x -> inv(x))(A) ≈ inv(A)
  @test tf(@net x -> det(x))(A) ≈ det(A)
  A = randn(Float32,(6,3))
  @test tf(@net x -> transpose(x))(A) ≈ transpose(A)
  A = randn(Float32,(6,3,2))
  @test tf(@net (x,y) -> permutedims(x,y))(A,[3,2,1]) ≈ permutedims(A,[3,2,1])
  A1 = randn(Float32,(4,1))
  A2 = randn(Float32,(4,1))
  @test tf(@net (x,y) -> cat(2,x,y))(A1,A2) ≈ cat(2,A1,A2)
  @test tf(@net x -> length(x))(A1) == length(A1)
  A = randn(Float32,(5,5))
  @test tf(@net x -> diag(x))(A) ≈ diag(A)
  A = randn(Float32,(5,))
  @test tf(@net x -> diagm(x))(A) ≈ diagm(A)
  A = randn(4,5)
  @test tf(@net x -> size(x))(A) == [4,5]
  @test tf(@net (x,y) -> size(x,y))(A,1) == 4
  A = randn(6,5)
  A = A'*A
  @test tf(@net x -> chol(x))(A) ≈ chol(A)
  A = randn(Float32,(6,3))
  @test transpose(tf(@net (x,y) -> reshape(x,y))(transpose(A),[2,9])) ≈ reshape(A,(9,2)) # Note: TF is row major and julia is not
  A = randn(Float32,(4,3,1))
  @test tf(@net (x,y) -> Flux.tile(x,y))(A,[1,1,3]) ≈ repeat(A,outer=(1,1,3))
  @test tf(@net (x,y) -> fill(x,y))(3.2,[3,2]) ≈ convert(Array{Float32},3.2*ones(3,2))
  @test typeof(tf(@net x -> Flux.cast(x,Int32))(A)) == Array{Int32,3}
  A = randn(Float32,(5,5))
  b = randn(Float32,(5,1))
  @test tf(@net (x,y) -> solve(x,y))(A,b) ≈ A\b
  _,A,_ = lu(A)
  @test tf(@net (x,y) -> triangular_solve(x,y))(A,b) ≈ A\b
  @test size(tf(@net x -> randu(x))([2,3])) == (2,3)
  @test size(tf(@net x -> randn(x))([2,3])) == (2,3)
  m = tf(@net (x,y) -> Flux.expand_dims(x,y))
  A = randn(Float32,(3,2))
  @test m(A,1) ≈ Flux.expand_dims(A,1)
  @test m(A,2) ≈ Flux.expand_dims(A,2)
  @test m(A,3) ≈ Flux.expand_dims(A,3)
end

end
test/runtests.jl
@@ -2,14 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
 using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
 using DataFlow: Line, Frame
 
-macro mxonly(ex)
-  :(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
-end
-
-macro tfonly(ex)
-  :(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
-end
-
 @testset "Flux" begin
 
 include("batching.jl")
@@ -20,7 +12,4 @@ include("recurrent.jl")
 include("optimizer.jl")
 include("throttle.jl")
 
-@tfonly include("backend/tensorflow.jl")
-@mxonly include("backend/mxnet.jl")
-
 end
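The removed `@mxonly`/`@tfonly` macros gated the backend test files on whether the corresponding package was installed. The same idea as a small macro using `Base.find_package`, the later replacement for `find_in_path` (illustrative only, not part of this commit):

macro onlyif(pkg, ex)
  :(Base.find_package($(esc(pkg))) ≠ nothing && $(esc(ex)))
end

@onlyif "Test" println("Test is available")   # runs only if the package resolves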