fix backend imports
This commit is contained in:
parent
ddcd576a74
commit
415c5f6963
@ -10,7 +10,7 @@ using Base: @get!
|
||||
using DataFlow: Constant, constant
|
||||
using DataFlow.Interpreter
|
||||
using DataFlow.Interpreter: Exception, totrace
|
||||
import Flux: mapt, broadcastto, ∘
|
||||
import Flux: Reshape, MaxPool, flatten, mapt, broadcastto, ∘
|
||||
|
||||
# TODO: implement Julia's type promotion rules
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
using Flux: collectt, shapecheckt
|
||||
using Flux: collectt, shapecheckt, back!, update!
|
||||
|
||||
function copyargs!(as, bs)
|
||||
for id in intersect(keys(as), keys(bs))
|
||||
@ -134,7 +134,7 @@ function rewrite_softmax(model, name)
|
||||
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
||||
end
|
||||
|
||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, input, feedforward=true)
|
||||
ff = mx.FeedForward(graph.output, context = context)
|
||||
|
@ -87,9 +87,8 @@ function tograph(model, args...; variables = false)
|
||||
return ctx[:params], ctx[:stacks], out
|
||||
end
|
||||
|
||||
# TODO: replace this
|
||||
# TensorFlow.Tensor(m::Flux.Model, args...) =
|
||||
# tograph(m, args...; variables = true)[3]
|
||||
astensor(model, args...) =
|
||||
tograph(model, args...; variables = true)[3]
|
||||
|
||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||
|
||||
|
@ -15,13 +15,15 @@ test_stacktrace(mxnet)
|
||||
test_back(mxnet)
|
||||
test_anon(mxnet)
|
||||
|
||||
using Flux: MaxPool
|
||||
|
||||
@testset "Native interface" begin
|
||||
f = mx.FeedForward(Chain(d, softmax))
|
||||
f = Flux.MX.FeedForward(Chain(d, softmax))
|
||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||
|
||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||
flatten, Affine(1587, 10), softmax)
|
||||
f = mx.FeedForward(m)
|
||||
f = Flux.MX.FeedForward(m)
|
||||
# TODO: test run
|
||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||
end
|
||||
|
@ -17,7 +17,7 @@ test_anon(tf)
|
||||
@testset "Tensor interface" begin
|
||||
sess = TensorFlow.Session()
|
||||
X = placeholder(Float32)
|
||||
Y = Tensor(d, X)
|
||||
Y = Flux.TF.astensor(d, X)
|
||||
run(sess, global_variables_initializer())
|
||||
|
||||
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||
|
@ -1,5 +1,5 @@
|
||||
using Flux, DataFlow, MacroTools, Base.Test
|
||||
using Flux: graph, Param, squeeze, unsqueeze
|
||||
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
||||
using DataFlow: Line, Frame
|
||||
|
||||
macro mxonly(ex)
|
||||
|
Loading…
Reference in New Issue
Block a user