From 415c5f69634dd3d35a4e9f79424c0282348f430a Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Mon, 5 Jun 2017 16:32:16 +0100
Subject: [PATCH] fix backend imports

---
 src/backend/mxnet/graph.jl      | 2 +-
 src/backend/mxnet/model.jl      | 4 ++--
 src/backend/tensorflow/graph.jl | 5 ++---
 test/backend/mxnet.jl           | 6 ++++--
 test/backend/tensorflow.jl      | 2 +-
 test/runtests.jl                | 2 +-
 6 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index 4107b368..4e3eae01 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -10,7 +10,7 @@ using Base: @get!
 using DataFlow: Constant, constant
 using DataFlow.Interpreter
 using DataFlow.Interpreter: Exception, totrace
-import Flux: mapt, broadcastto, ∘
+import Flux: Reshape, MaxPool, flatten, mapt, broadcastto, ∘
 
 # TODO: implement Julia's type promotion rules
 
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 9fe7e8d2..15bd69d5 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -1,4 +1,4 @@
-using Flux: collectt, shapecheckt
+using Flux: collectt, shapecheckt, back!, update!
 
 function copyargs!(as, bs)
   for id in intersect(keys(as), keys(bs))
@@ -134,7 +134,7 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
-function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
+function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
   graph = tograph(model, input, feedforward=true)
   ff = mx.FeedForward(graph.output, context = context)
diff --git a/src/backend/tensorflow/graph.jl b/src/backend/tensorflow/graph.jl
index 365345f3..9a94b38c 100644
--- a/src/backend/tensorflow/graph.jl
+++ b/src/backend/tensorflow/graph.jl
@@ -87,9 +87,8 @@ function tograph(model, args...; variables = false)
   return ctx[:params], ctx[:stacks], out
 end
 
-# TODO: replace this
-# TensorFlow.Tensor(m::Flux.Model, args...) =
-#   tograph(m, args...; variables = true)[3]
+astensor(model, args...) =
+  tograph(model, args...; variables = true)[3]
 
 RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
 
diff --git a/test/backend/mxnet.jl b/test/backend/mxnet.jl
index c3e39806..f0aeb0b7 100644
--- a/test/backend/mxnet.jl
+++ b/test/backend/mxnet.jl
@@ -15,13 +15,15 @@ test_stacktrace(mxnet)
 test_back(mxnet)
 test_anon(mxnet)
 
+using Flux: MaxPool
+
 @testset "Native interface" begin
-  f = mx.FeedForward(Chain(d, softmax))
+  f = Flux.MX.FeedForward(Chain(d, softmax))
   @test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
   m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
             flatten, Affine(1587, 10), softmax)
-  f = mx.FeedForward(m)
+  f = Flux.MX.FeedForward(m)
   # TODO: test run
   @test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
 end
 
diff --git a/test/backend/tensorflow.jl b/test/backend/tensorflow.jl
index b727f5e1..00f45c95 100644
--- a/test/backend/tensorflow.jl
+++ b/test/backend/tensorflow.jl
@@ -17,7 +17,7 @@ test_anon(tf)
 @testset "Tensor interface" begin
   sess = TensorFlow.Session()
   X = placeholder(Float32)
-  Y = Tensor(d, X)
+  Y = Flux.TF.astensor(d, X)
 
   run(sess, global_variables_initializer())
   @test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
diff --git a/test/runtests.jl b/test/runtests.jl
index 4a86c6e5..8dd1dd8e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,5 @@
 using Flux, DataFlow, MacroTools, Base.Test
-using Flux: graph, Param, squeeze, unsqueeze
+using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
 using DataFlow: Line, Frame
 
 macro mxonly(ex)
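
For context, a minimal usage sketch of the two renamed entry points, based on how the updated tests above call them. It assumes the MXNet and TensorFlow backend modules are reachable as Flux.MX and Flux.TF (as the tests reference them), and the model shape mirrors the test fixture `d`:

    using Flux, TensorFlow
    d = Affine(20, 10)

    # MXNet: the backend now defines its own FeedForward wrapper instead of
    # pirating mx.FeedForward, so callers go through the backend module.
    f = Flux.MX.FeedForward(Chain(d, softmax))

    # TensorFlow: the commented-out TensorFlow.Tensor overload becomes an
    # explicit astensor helper that converts a Flux model into a graph tensor.
    sess = TensorFlow.Session()
    X = placeholder(Float32)
    Y = Flux.TF.astensor(d, X)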