fix backend imports
This commit is contained in:
parent
ddcd576a74
commit
415c5f6963
@ -10,7 +10,7 @@ using Base: @get!
|
|||||||
using DataFlow: Constant, constant
|
using DataFlow: Constant, constant
|
||||||
using DataFlow.Interpreter
|
using DataFlow.Interpreter
|
||||||
using DataFlow.Interpreter: Exception, totrace
|
using DataFlow.Interpreter: Exception, totrace
|
||||||
import Flux: mapt, broadcastto, ∘
|
import Flux: Reshape, MaxPool, flatten, mapt, broadcastto, ∘
|
||||||
|
|
||||||
# TODO: implement Julia's type promotion rules
|
# TODO: implement Julia's type promotion rules
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
using Flux: collectt, shapecheckt
|
using Flux: collectt, shapecheckt, back!, update!
|
||||||
|
|
||||||
function copyargs!(as, bs)
|
function copyargs!(as, bs)
|
||||||
for id in intersect(keys(as), keys(bs))
|
for id in intersect(keys(as), keys(bs))
|
||||||
@ -134,7 +134,7 @@ function rewrite_softmax(model, name)
|
|||||||
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
||||||
end
|
end
|
||||||
|
|
||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, input, feedforward=true)
|
graph = tograph(model, input, feedforward=true)
|
||||||
ff = mx.FeedForward(graph.output, context = context)
|
ff = mx.FeedForward(graph.output, context = context)
|
||||||
|
@ -87,9 +87,8 @@ function tograph(model, args...; variables = false)
|
|||||||
return ctx[:params], ctx[:stacks], out
|
return ctx[:params], ctx[:stacks], out
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: replace this
|
astensor(model, args...) =
|
||||||
# TensorFlow.Tensor(m::Flux.Model, args...) =
|
tograph(model, args...; variables = true)[3]
|
||||||
# tograph(m, args...; variables = true)[3]
|
|
||||||
|
|
||||||
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
RawTensor(data::Union{Batch,Seq}) = RawTensor(rawbatch(data))
|
||||||
|
|
||||||
|
@ -15,13 +15,15 @@ test_stacktrace(mxnet)
|
|||||||
test_back(mxnet)
|
test_back(mxnet)
|
||||||
test_anon(mxnet)
|
test_anon(mxnet)
|
||||||
|
|
||||||
|
using Flux: MaxPool
|
||||||
|
|
||||||
@testset "Native interface" begin
|
@testset "Native interface" begin
|
||||||
f = mx.FeedForward(Chain(d, softmax))
|
f = Flux.MX.FeedForward(Chain(d, softmax))
|
||||||
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 1))[2] == [(10, 1)]
|
||||||
|
|
||||||
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
m = Chain(Input(28,28), Conv2D((5,5), out = 3), MaxPool((2,2)),
|
||||||
flatten, Affine(1587, 10), softmax)
|
flatten, Affine(1587, 10), softmax)
|
||||||
f = mx.FeedForward(m)
|
f = Flux.MX.FeedForward(m)
|
||||||
# TODO: test run
|
# TODO: test run
|
||||||
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
@test mx.infer_shape(f.arch, data = (20, 20, 5, 1))[2] == [(10, 1)]
|
||||||
end
|
end
|
||||||
|
@ -17,7 +17,7 @@ test_anon(tf)
|
|||||||
@testset "Tensor interface" begin
|
@testset "Tensor interface" begin
|
||||||
sess = TensorFlow.Session()
|
sess = TensorFlow.Session()
|
||||||
X = placeholder(Float32)
|
X = placeholder(Float32)
|
||||||
Y = Tensor(d, X)
|
Y = Flux.TF.astensor(d, X)
|
||||||
run(sess, global_variables_initializer())
|
run(sess, global_variables_initializer())
|
||||||
|
|
||||||
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
@test run(sess, Y, Dict(X=>xs)) ≈ d(xs)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
using Flux, DataFlow, MacroTools, Base.Test
|
using Flux, DataFlow, MacroTools, Base.Test
|
||||||
using Flux: graph, Param, squeeze, unsqueeze
|
using Flux: graph, Param, squeeze, unsqueeze, back!, update!, flatten
|
||||||
using DataFlow: Line, Frame
|
using DataFlow: Line, Frame
|
||||||
|
|
||||||
macro mxonly(ex)
|
macro mxonly(ex)
|
||||||
|
Loading…
Reference in New Issue
Block a user