fix backend imports

This commit is contained in:
Mike J Innes 2017-06-05 16:32:16 +01:00
parent ddcd576a74
commit 415c5f6963
6 changed files with 11 additions and 10 deletions

View File

@ -10,7 +10,7 @@ using Base: @get!
using DataFlow: Constant, constant
using DataFlow.Interpreter
using DataFlow.Interpreter: Exception, totrace
import Flux: mapt, broadcastto,
import Flux: Reshape, MaxPool, flatten, mapt, broadcastto,
# TODO: implement Julia's type promotion rules

View File

@ -1,4 +1,4 @@
using Flux: collectt, shapecheckt
using Flux: collectt, shapecheckt, back!, update!
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
@ -134,7 +134,7 @@ function rewrite_softmax(model, name)
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
graph = tograph(model, input, feedforward=true)
ff = mx.FeedForward(graph.output, context = context)

View File

@ -87,9 +87,8 @@ function tograph(model, args...; variables = false)
return ctx[:params], ctx[:stacks], out
end
# TODO: replace this
# TensorFlow.Tensor(m::Flux.Model, args...) =
# tograph(m, args...; variables = true)[3]
astensor(model, args...) =
tograph(model, args...; variables = true)[3]
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))

View File

@ -15,13 +15,15 @@ test_stacktrace(mxnet)
test_back(mxnet)
test_anon(mxnet)
using Flux: MaxPool
@testset "Native interface" begin
f = mx.FeedForward(Chain(d, softmax))
f = Flux.MX.FeedForward(Chain(d, softmax))
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
flatten, Affine(1587, 10), softmax)
f = mx.FeedForward(m)
f = Flux.MX.FeedForward(m)
# TODO: test run
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
end

View File

@ -17,7 +17,7 @@ test_anon(tf)
@testset "Tensor interface" begin
sess = TensorFlow.Session()
X = placeholder(Float32)
Y = Tensor(d, X)
Y = Flux.TF.astensor(d, X)
run(sess, global_variables_initializer())
@test run(sess, Y, Dict(X=>xs)) d(xs)

View File

@ -1,5 +1,5 @@
using Flux, DataFlow, MacroTools, Base.Test
using Flux: graph, Param, squeeze, unsqueeze
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
using DataFlow: Line, Frame
macro mxonly(ex)