remove backends

Mike J Innes 2017-08-18 00:44:22 +01:00
parent 4604e9f515
commit 536949891d
12 changed files with 0 additions and 741 deletions

View File

@ -38,8 +38,6 @@ include("layers/cost.jl")
include("layers/recurrent.jl") include("layers/recurrent.jl")
include("layers/shims.jl") include("layers/shims.jl")
include("backend/backend.jl")
include("data.jl") include("data.jl")
include("training.jl") include("training.jl")

View File

@ -1,28 +0,0 @@
# We use a lazy-loading trick to load the backend code as needed; this avoids
# the need for a hard dependency on both backends.
# This is effectively equivalent to:
# include("tensorflow/tensorflow.jl")
# using .TF
# export tf
# but instead of loading immediately, we wait until `tf` is first called.
function loadtf()
isdefined(Flux, :TF) && return
@eval include(joinpath(dirname($@__FILE__), "tensorflow/tensorflow.jl"))
end
function tf(args...)
loadtf()
eval(:(TF.tf($(QuoteNode.(args)...))))
end
function loadmx()
isdefined(Flux, :MX) && return
@eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
end
function mxnet(args...)
loadmx()
eval(:(MX.mxnet($(QuoteNode.(args)...))))
end
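For context, the lazy-loading shims deleted above were driven entirely by the first call to `tf` or `mxnet`; from the user's side the pattern looked roughly like this (a hedged sketch, assuming the relevant backend package is installed and using the `Affine` layer that appears in the tests below):

using Flux
d = Affine(20, 10)   # an ordinary Flux model
dm = mxnet(d)        # first call runs loadmx(), which includes the MXNet backend on demand
dt = tf(d)           # likewise, the first call to tf runs loadtf() for the TensorFlow backend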

View File

@ -1,142 +0,0 @@
function nodename(s::mx.SymbolicNode)
name = Ref{mx.char_p}(0)
success = Ref(0)
mx.@mxcall(:MXSymbolGetName, (mx.MX_handle, Ref{mx.char_p}, Ref{Int}), s.handle.value, name, success)
@assert success[] != -1
return Symbol(unsafe_string(name[]))
end
using Base: @get!
using DataFlow: Constant, constant
using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace
import Flux: Reshape, MaxPool, flatten, mapt, broadcastto, ∘
# TODO: implement Julia's type promotion rules
node(x::Tuple) = map(node, x)
node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(identity), x) = x
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
graph(::typeof(flatten), x) = mx.Flatten(x)
graph(::typeof(hcat), xs...) = mx.concat(xs..., dim = 2-1)
graph(::typeof(vec), xs) = reshape(xs, shape = (-1,))
graph(::typeof(broadcast), ::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(broadcast), ::typeof(-), args...) = mx.broadcast_sub(args...)
graph(::typeof(broadcast), ::typeof(*), args...) = mx.broadcast_mul(args...)
graph(::typeof(broadcast), ::typeof(/), args...) = mx.broadcast_div(args...)
graph(::typeof(broadcastto), xs, shape) = mx.broadcast_to(xs, shape = map(i -> i≤1?0:i, reverse(shape)))
# Old broadcasters
graph(::typeof(broadcast), ::typeof(exp), xs) = exp(xs)
graph(::typeof(.+), args...) = mx.broadcast_plus(args...)
graph(::typeof(.*), args...) = mx.broadcast_mul(args...)
graph(::typeof(.-), args...) = mx.broadcast_sub(args...)
graph(::typeof(softmax), xs) =
mx.broadcast_div(exp(xs), mx.sum(exp(xs), axis = 1, keepdims=true))
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
graph(::typeof(vcat), a...) = graph(cat, 1, a...)
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(sum), xs::Tuple) = reduce((a, b) -> graph(broadcast, +, a, b), xs)
graph(::typeof(Base.Iterators.repeated), x, n) = ntuple(_ -> x, n)
a::mx.SymbolicNode ∘ b::mx.SymbolicNode = mx.broadcast_mul(a, b)
graph(::Input, x) = x
struct AlterParam
param
load
store
end
Base.size(p::AlterParam) = size(p.load(p.param.x))
Base.copy!(xs, p::AlterParam) = copy!(xs, p.load(p.param.x))
graph(ctx::Context, d::Affine, x) =
!ctx[:feedforward] ? invoke(graph, Tuple{Context, Any, typeof(x)}, ctx, d, x) :
register(ctx,
mx.FullyConnected(mx.SymbolicNode, data = x,
num_hidden = size(d.W.x, 2),
weight = var(ctx, AlterParam(d.W, x->x', nothing)),
bias = var(ctx, AlterParam(d.b, x->squeeze(x, 1), nothing))))
# TODO: use actual params
graph(ctx::Context, c::Conv2D, x) =
mx.Convolution(x,
kernel = size(c.filter, 1, 2),
num_filter = size(c.filter, 4),
stride = c.stride)
graph(ctx::Context, p::MaxPool, x) =
mx.Pooling(x,
pool_type = :max,
kernel = p.size,
stride = p.stride)
function register(ctx::Context, node::mx.SymbolicNode)
ctx[:stacks][nodename(node)] = stack(ctx)
return node
end
register(ctx::Context, node) = node
function var(ctx::Context, p::Union{Flux.Param{<:AbstractArray},AbstractArray,AlterParam})
haskey(ctx[:params], p) && return ctx[:params][p]
ctx[:params][p] = mx.Variable(gensym())
end
var(ctx::Context, x) = x
function graph(ctx::Context, model, args...)
args = var.(ctx, args)
g = Flux.graph(model)
g == nothing && return register(ctx, @icatch ctx graph(model, args...))
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, args...)
end
graph(ctx::Context, args...) = @icatch ctx graph(ctx, args...)
function tograph(model, args...; feedforward = false)
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, graph),
params = ObjectIdDict(), stacks = Dict(),
feedforward = feedforward)
out = @ithrow graph(ctx, model, mapt(mx.Variable, args)...)
params = Dict(nodename(v) => p for (p, v) in ctx[:params])
return Graph(args, out, params, ctx[:stacks])
end
# Error Handling
using Juno
using MacroTools: @q
Juno.errmsg(e::mx.MXError) = e.msg
function errnode(e::mx.MXError)
m = match(r"Error in operator (\w+)", e.msg)
m == nothing && return
Symbol(m.captures[1])
end
striptrace(e::mx.MXError) = mx.MXError(split(e.msg, "\n")[1])
macro mxerr(stk, ex)
@q try
$(esc(ex))
catch e
(e isa mx.MXError && (node = errnode(e)) != nothing) || rethrow()
stk = $(esc(stk))
haskey(stk, node) || rethrow()
throw(Exception(striptrace(e), totrace(stk[node])))
end
end
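The `@mxerr` machinery above pairs with the stacks recorded by `register`, so an MXNet operator failure can be rethrown with the Julia-level trace of the node that built it; the model code later in this commit wraps its forward pass in the same way, roughly as in this hedged sketch (`forward_pass` is a hypothetical stand-in for the wrapped call):

@mxerr stacks begin            # `stacks` maps node names to DataFlow stacks, as filled in by register()
forward_pass(args...)          # hypothetical call that may throw an mx.MXError naming an operator
end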

View File

@ -1,159 +0,0 @@
using Flux: collectt, shapecheckt, back!, update!
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id])
end
end
struct Graph
input
output
params::Dict{Symbol,Any}
stacks::Dict{Any,Any}
end
function mxparams(ps, ctx)
params = Dict{Symbol,MXArray}()
for (name, param) in ps
params[name] = MXArray(size(param), ctx)
end
return params
end
ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
struct Exec
graph::Graph
ctx::mx.Context
exec::mx.Executor
args::Dict{Symbol,MXArray}
grads::Dict{Symbol,MXArray}
outs::Vector{MXArray}
end
loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)
mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function executor(graph::Graph, input...; ctx = mx.cpu())
shapecheckt(graph.input, input)
args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
grads = filter((a, b) -> b isa Flux.Param, graph.params)
grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
exec = mx.bind(mxgroup(graph.output),
context = ctx,
args = ndparams(args),
args_grad = ndparams(grads),
grad_req = mx.GRAD_ADD)
exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
loadparams!(exec)
return exec
end
function (exec::Exec)(input...)
foreach(kv -> copy!(exec.args[kv[1]], kv[2]), dictt(exec.graph.input, input))
mx.forward(exec.exec, is_train = true)
mxungroup(exec.graph.output, copy(exec.outs))
end
function Flux.back!(exec::Exec, Δ)
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
mapt(k -> copy(exec.grads[k]), exec.graph.input)
end
function Flux.update!(exec::Exec, η)
for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
grad == nothing && continue
mx.@nd_as_jl rw = (arg, grad) begin
arg .-= grad .* η
grad[:] = 0
end
end
storeparams!(exec)
return exec
end
toctx(ctx::mx.Context) = ctx
toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()
# TODO: if `last` changes, update params appropriately
mutable struct Model
model::Any
ctx::mx.Context
execs::Dict{Tuple,Exec}
graph::Graph
last::Exec
Model(model, ctx) = new(model, ctx, Dict())
end
mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))
function Base.show(io::IO, m::Model)
print(io, "MX.Model(")
show(io, m.model)
print(io, ", ")
show(io, m.ctx)
print(io, ")")
end
import Base: @get!
# TODO: dims having its own type would be useful
executor(m::Model, input...) =
@get!(m.execs, mapt(size, input),
executor(m.graph, input...; ctx = m.ctx))
function (m::Model)(xs...)
@mxerr m.graph.stacks begin
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
m.last = exec = executor(m, xs...)
exec(xs...)
end
end
function Flux.back!(m::Model, Δ, xs...)
m.last = exec = m.execs[mapt(size, xs)]
back!(exec, Δ)
end
Flux.update!(m::Model, η) = (update!(m.last, η); m)
# Recurrent Models
using Flux: Stateful, SeqModel
mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate)
mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps)
# MX FeedForward interface
struct SoftmaxOutput
name::Symbol
end
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(xs, name = s.name)
function rewrite_softmax(model, name)
model == softmax && return SoftmaxOutput(name)
g = Flux.graph(model)
(g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
model = rewrite_softmax(model, label)
graph = tograph(model, input, feedforward=true)
ff = mx.FeedForward(graph.output, context = ctx)
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
return ff
end
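As a usage note, this wrapper compiled one MXNet executor per distinct input shape and cached it in `execs`; a round trip looked roughly like the following hedged sketch (shapes, data and learning rate are illustrative, mirroring the tests below):

d = Affine(20, 10)
m = mxnet(d, :cpu)                        # graph and executor are built lazily
y = m(rand(1, 20))                        # first call builds an Exec keyed by the input sizes
Flux.back!(m, ones(1, 10), rand(1, 20))   # backward pass on the cached executor for that shape
Flux.update!(m, 0.1)                      # SGD step, then parameters are written back via storeparams!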

View File

@ -1,40 +0,0 @@
using MXNet
# NDArray is row-major so by default all dimensions are reversed in MXNet.
# MXArray transposes when loading/storing to fix this.
reversedims!(dest, xs) = permutedims!(dest, xs, ndims(xs):-1:1)
struct MXArray{N}
data::mx.NDArray
scratch::Array{Float32,N}
end
MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
# TODO: split cpu/gpu mxarrays
MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))
Base.size(xs::MXArray) = reverse(size(xs.data))
function Base.copy!(mx::MXArray, xs::AbstractArray)
@assert size(mx) == size(xs)
reversedims!(mx.scratch, xs)
copy!(mx.data, mx.scratch)
return mx
end
function Base.copy!(xs::AbstractArray, mx::MXArray)
@assert size(xs) == size(mx)
copy!(mx.scratch, mx.data)
reversedims!(xs, mx.scratch)
end
Base.copy(mx::MXArray) = copy!(Array{Float32}(size(mx)), mx)
function MXArray(xs::AbstractArray, ctx = mx.cpu())
mx = MXArray(size(xs), ctx)
copy!(mx, xs)
end
Base.setindex!(xs::MXArray, x::Real, ::Colon) = xs.data[:] = x
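The dimension-reversal contract can be summarised by a small hedged sketch (requires MXNet.jl; the values are illustrative):

xs = rand(Float32, 2, 3, 4)
m = MXArray(xs)               # held internally as a reversed (4, 3, 2) mx.NDArray
@assert size(m) == (2, 3, 4)  # size() reverses back to Julia's column-major view
@assert copy(m) == xs         # load/store round-trips through reversedims! without change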

View File

@ -1,11 +0,0 @@
module MX
using MXNet, DataFlow, ..Flux
export mxnet
include("mxarray.jl")
include("graph.jl")
include("model.jl")
end

View File

@ -1,133 +0,0 @@
using Base: @get!
using Flux: Reshape, MaxPool, flatten
using DataFlow: constant, Split
using DataFlow.Interpreter
using DataFlow.Interpreter: stack
using TensorFlow: RawTensor, TFException
# TODO: implement Julia's type promotion rules
node(x::Tuple) = map(node, x)
node(x::Tensor) = x
node(x::Variable) = x
node(x::Number) = TensorFlow.constant(Float32(x))
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(getindex), t::Tuple, n::Integer) = t[n]
graph(::typeof(identity), x) = TensorFlow.identity(x)
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
graph(::typeof(σ), x) = nn.sigmoid(x)
graph(::typeof(hcat), xs...) = concat(1, xs)
graph(::typeof(sum), x, dim=nothing) = TensorFlow.reduce_sum(x;axis=dim)
graph(::typeof(prod), x, dim=nothing) = TensorFlow.reduce_prod(x;axis=dim)
graph(::typeof(min), x, dim=nothing) = TensorFlow.reduce_min(x;axis=dim)
graph(::typeof(max), x, dim=nothing) = TensorFlow.reduce_max(x;axis=dim)
graph(::typeof(all), x, dim=nothing) = TensorFlow.reduce_all(x;axis=dim)
graph(::typeof(any), x, dim=nothing) = TensorFlow.reduce_any(x;axis=dim)
graph(::typeof(mean), x, dim=nothing) = TensorFlow.reduce_mean(x;axis=dim)
graph(::typeof(svd), x) = svd(x)
graph(::typeof(size), x, dim) = TensorFlow.size(x,convert(Tensor{Int32}, dim))
graph(::typeof(size), x) = TensorFlow.size(x)
graph(::typeof(chol), args...) = TensorFlow.transpose(TensorFlow.cholesky(args...))
graph(::typeof(reshape), x, dims) = TensorFlow.reshape(x,convert(Tensor{Int32},dims))
graph(::typeof(Flux.tile), args...) = TensorFlow.tile(args...)
graph(::typeof(fill), x, dims) = Ops.fill(convert(Tensor{Int32}, dims), Tensor(x))
graph(::typeof(Flux.cast), args...) = TensorFlow.cast(args...)
graph(::typeof(solve), A, b) = TensorFlow.matrix_solve(A, b)
graph(::typeof(triangular_solve), A, b) = TensorFlow.matrix_triangular_solve(A, b; lower=false)
graph(::typeof(randu), x) = Ops.random_uniform(convert(Tensor{Int32},x);dtype=Float32)
graph(::typeof(randn), x) = TensorFlow.random_normal(convert(Tensor{Int32},x);dtype=Float32)
graph(::typeof(Flux.expand_dims), x, dim) = TensorFlow.expand_dims(x,convert(Tensor{Int32},dim))
for op in (*, .*, .+, .^, log, exp, ceil, floor, sqrt, abs, cos,
sin, tan, atan, asin, acos, tanh, lgamma, erf, erfc, real, imag, conj,
inv, det, transpose, permutedims, cat, length, diag, diagm)
@eval graph(::typeof($op), args...) = $op(args...)
end
for op in (+, -, *, /)
@eval graph(::typeof(broadcast), ::typeof($op), args...) = broadcast($op, args...)
end
graph(::typeof(.-), args...) = -(args...)
graph(::typeof(map), f, xss::Tuple...) = map(f, xss...)
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x), Int32(-1)]))
graph(r::Reshape, x) = reshape(x, pack([batchsize(x), map(Int32, r.dims)...]))
graph(::Input, x) = x
graph(p::MaxPool, x) =
nn.max_pool(x, [1, p.size..., 1], [1, p.stride..., 1], "VALID")
graph(op::Op, xs...) = op.f(xs...)
function graph(ctx::Context, model, args...)
node = graph(model, args...)
node isa Tensor && (ctx[:stacks][node.op.name] = stack(ctx))
return node
end
interp(ctx, c::Conv2D, x) =
nn.conv2d(x, interp(ctx, constant(c.filter)), [1,c.stride...,1], "VALID")
param(ctx, p::Flux.Param{<:AbstractArray}) =
haskey(ctx[:params], p) ?
ctx[:params][p] :
(ctx[:params][p] =
ctx[:variables] ?
Variable(Float32.(p.x)) :
placeholder(Float32))
param(ctx, x) = x
function interp(ctx, model, args...)
args = param.(ctx, args)
g = Flux.graph(model)
g == nothing && return graph(ctx, model, args...)
DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
interpret(ctx, g, args...)
end
function tograph(model, args...; variables = false)
ctx = Context(mux(iline, iconst, ilambda, iargs, ituple, interp),
params = ObjectIdDict(), stacks = Dict(), variables = variables)
out = interp(ctx, model, map(constant, args)...)
return ctx[:params], ctx[:stacks], out
end
astensor(model, args...) =
tograph(model, args...; variables = true)[3]
RawTensor(data::Union{Flux.Batch,Flux.Seq}) = RawTensor(Flux.rawbatch(data))
# Error Handling
using Juno
using MacroTools: @q
using DataFlow.Interpreter: Exception, totrace
Juno.errmsg(e::TFException) = string(e.status)
function errnode(e::TFException)
m = match(r"Node: ([\w\d]+) =", string(e.status))
m == nothing && return
m.captures[1]
end
errnode(e) = nothing
macro tferr(stk, ex)
@q try
$(esc(ex))
catch e
(node = errnode(e)) != nothing || rethrow()
stk = $(esc(stk))
haskey(stk, node) || rethrow()
throw(Exception(e, totrace(stk[node])))
end
end

View File

@ -1,86 +0,0 @@
using Flux: Param, mapt, collectt, shapecheckt
struct Exec
session ::Session
input ::Any
output ::Any
params ::Dict{Param,Param{Tensor}}
stacks ::Dict{Any,Any}
end
dummy(x::Void) = TensorFlow.constant(0)
dummy(x::Tensor) = x
function makesession(model, inputs; session = Session(Graph()))
inputs = mapt(_ -> placeholder(Float32), inputs)
params, stacks, output = tograph(model, inputs...)
output = mapt(x->Param{Tensor}(x, placeholder(Float32)), output)
params = Dict(x=>Param{Tensor}(y, dummy(gradients(map(x->x.x, collectt(output)),
y, map(x->x.Δx, collectt(output)))))
for (x, y) in params)
inputs = mapt(x->Param{Tensor}(x, dummy(gradients(map(x->x.x, collectt(output)),
x, map(x->x.Δx, collectt(output))))),
inputs)
run(session, global_variables_initializer())
Exec(session, inputs, output, params, stacks)
end
retuple(xs) = xs
retuple(xs::AbstractArray{<:AbstractArray}) = (retuple.(xs)...,)
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function (m::Exec)(args...)
dict = merge(
Dict(y.x=>x.x for (x, y) in m.params),
Dict(x.x=>y for (x, y) in dictt(m.input, args))
)
retuple(run(m.session, mapt(x->x.x, m.output), dict))
end
function Flux.back!(m::Exec, Δ, args...)
dict = merge(
Dict(y.x=>x.x for (x, y) in m.params),
Dict(x.x=>y for (x, y) in zip(m.input, args)),
Dict(x.Δx=>y for (x, y) in zip(collectt(m.output), collectt(Δ)))
)
Δin, Δps = run(m.session, (mapt(x->x.Δx, m.input), map(x->x.Δx, values(m.params))), dict)
for (p, Δ) in zip(keys(m.params), Δps)
p.Δx .+= Δ
end
Δin
end
function Flux.update!(m::Exec, η)
for p in keys(m.params)
Flux.update!(p, η)
end
return m
end
mutable struct Model
model::Any
exec::Exec
Model(model) = new(model)
end
tf(model) = Model(model)
function (m::Model)(args...)
args = mapt(x->Float32.(x), args)
isdefined(m, :exec) || (m.exec = makesession(m.model, args))
@tferr m.exec.stacks m.exec(args...)
end
Flux.back!(m::Model, Δ, args...) = Flux.back!(m.exec, Δ, args...)
Flux.update!(m::Model, η) = (Flux.update!(m.exec, η); m)
# Recurrent Models
using Flux: Stateful, SeqModel
tf(m::Stateful) = Stateful(tf(m.model), m.states, m.istate, m.ostate)
tf(m::SeqModel) = SeqModel(tf(m.model), m.steps)
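For reference, the removed TF wrapper followed the same forward/back!/update! protocol as the MXNet one, but accumulated gradients into each `Param`'s `Δx` before `update!` applied them; a hedged sketch of one training step (data and learning rate are illustrative):

d = Affine(20, 10)
dt = tf(d)                                 # the TF session is created lazily on the first call
y = dt(rand(1, 20))                        # forward pass through the session
Flux.back!(dt, ones(1, 10), rand(1, 20))   # runs the gradient ops and accumulates into each Param's Δx
Flux.update!(dt, 0.1)                      # applies the accumulated gradients to the parameters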

View File

@ -1,20 +0,0 @@
module TF
using ..Flux, DataFlow, TensorFlow, Juno
import Flux: accuracy, convertel
export tf
struct Op
f
shape
end
Op(f) = Op(f, (d...) -> nothing)
Flux.shape(op::Op, d...) = op.shape(d...)
include("graph.jl")
include("model.jl")
end

View File

@ -1,39 +0,0 @@
using MXNet
Flux.loadmx()
@testset "MXNet" begin
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dm = mxnet(d)
@test d(xs) ≈ dm(xs)
test_tupleio(mxnet)
test_recurrence(mxnet)
test_stacktrace(mxnet)
test_back(mxnet)
test_anon(mxnet)
using Flux: MaxPool
@testset "Native interface" begin
f = Flux.MX.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = Flux.MX.FeedForward(m)
# TODO: test run
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end
@testset "Duplicate parameters" begin
a = Affine(10, 10)
d = Chain(a, a)
m = mxnet(d)
m(randn(1, 10))
@test length(m.graph.params) == 2
end
end

View File

@ -1,70 +0,0 @@
using TensorFlow
Flux.loadtf()
@testset "TensorFlow" begin
xs, ys = rand(1, 20), rand(1, 20)
d = Affine(20, 10)
dt = tf(d)
@test d(xs) ≈ dt(xs)
test_tupleio(tf)
test_recurrence(tf)
test_stacktrace(tf)
test_anon(tf)
@testset "Tensor interface" begin
sess = TensorFlow.Session()
X = placeholder(Float32)
Y = Flux.TF.astensor(d, X)
run(sess, global_variables_initializer())
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
end
@testset "Ops" begin
A = randn(Float32,(5,5))
# u,s,v = tf(@net x -> svd(x))(A)
# @test A ≈ u*diagm(s)*transpose(v)
@test tf(@net x -> inv(x))(A) ≈ inv(A)
@test tf(@net x -> det(x))(A) ≈ det(A)
A = randn(Float32,(6,3))
@test tf(@net x -> transpose(x))(A) ≈ transpose(A)
A = randn(Float32,(6,3,2))
@test tf(@net (x,y) -> permutedims(x,y))(A,[3,2,1]) ≈ permutedims(A,[3,2,1])
A1 = randn(Float32,(4,1))
A2 = randn(Float32,(4,1))
@test tf(@net (x,y) -> cat(2,x,y))(A1,A2) ≈ cat(2,A1,A2)
@test tf(@net x -> length(x))(A1) == length(A1)
A = randn(Float32,(5,5))
@test tf(@net x -> diag(x))(A) ≈ diag(A)
A = randn(Float32,(5,))
@test tf(@net x -> diagm(x))(A) ≈ diagm(A)
A = randn(4,5)
@test tf(@net x -> size(x))(A) == [4,5]
@test tf(@net (x,y) -> size(x,y))(A,1) == 4
A = randn(6,5)
A = A'*A
@test tf(@net x -> chol(x))(A) ≈ chol(A)
A = randn(Float32,(6,3))
@test transpose(tf(@net (x,y) -> reshape(x,y))(transpose(A),[2,9])) ≈ reshape(A,(9,2)) # Note: TF is row major and julia is not
A = randn(Float32,(4,3,1))
@test tf(@net (x,y) -> Flux.tile(x,y))(A,[1,1,3]) ≈ repeat(A,outer=(1,1,3))
@test tf(@net (x,y) -> fill(x,y))(3.2,[3,2]) ≈ convert(Array{Float32},3.2*ones(3,2))
@test typeof(tf(@net x -> Flux.cast(x,Int32))(A)) == Array{Int32,3}
A = randn(Float32,(5,5))
b = randn(Float32,(5,1))
@test tf(@net (x,y) -> solve(x,y))(A,b) ≈ A\b
_,A,_ = lu(A)
@test tf(@net (x,y) -> triangular_solve(x,y))(A,b) ≈ A\b
@test size(tf(@net x -> randu(x))([2,3])) == (2,3)
@test size(tf(@net x -> randn(x))([2,3])) == (2,3)
m = tf(@net (x,y) -> Flux.expand_dims(x,y))
A = randn(Float32,(3,2))
@test m(A,1) ≈ Flux.expand_dims(A,1)
@test m(A,2) ≈ Flux.expand_dims(A,2)
@test m(A,3) ≈ Flux.expand_dims(A,3)
end
end

View File

@ -2,14 +2,6 @@ using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
using DataFlow: Line, Frame
macro mxonly(ex)
:(Base.find_in_path("MXNet") ≠ nothing && $(esc(ex)))
end
macro tfonly(ex)
:(Base.find_in_path("TensorFlow") ≠ nothing && $(esc(ex)))
end
@testset "Flux" begin
include("batching.jl")
@ -20,7 +12,4 @@ include("recurrent.jl")
include("optimizer.jl")
include("throttle.jl")
@tfonly include("backend/tensorflow.jl")
@mxonly include("backend/mxnet.jl")
end