This commit is contained in:
Mike J Innes 2016-10-10 23:04:26 +01:00
parent 409c44d362
commit 4961bf72af
1 changed file with 16 additions and 9 deletions

View File

@@ -1,15 +1,11 @@
module TF
using ..Flux, Flow, TensorFlow
import Juno: info
using ..Flux, Flow, TensorFlow, Juno
import Flux: accuracy
import Juno: info
export tf
# Workaround for tensor display bug
using Juno
# Render any TensorFlow `Tensor` as a fixed placeholder string in Juno's
# clipboard/display pipeline, instead of letting the buggy default show path run.
# NOTE(review): this extends Juno's Media.render for a type owned by
# TensorFlow.jl (type piracy) — acceptable here only as a temporary workaround.
Media.render(::Juno.Clipboard, ::Tensor) = "Tensor()"
"""
    cvalue(x)

Unwrap a concrete value from the Flow graph representation: a `Constant`
yields its stored value, a `Vertex` is unwrapped recursively via `value`,
and any other input is returned unchanged.
"""
function cvalue(x)
    return x
end

function cvalue(c::Constant)
    return c.value
end

function cvalue(v::Vertex)
    return cvalue(value(v))
end
@@ -36,7 +32,10 @@ graph(::typeof(+), args...) = +(args...)
# Map Flux activation functions onto their TensorFlow graph counterparts.
graph(::typeof(softmax), x) = nn.softmax(x)
graph(::typeof(relu), x) = nn.relu(x)
# Plain tanh(x) here dispatches to TensorFlow.jl's Tensor method —
# NOTE(review): relies on that overload existing in TensorFlow.jl; confirm.
graph(::typeof(tanh), x) = tanh(x)
# Original flatten: collapse to a rank-1 tensor, discarding the batch dimension.
graph(::typeof(flatten), x) = reshape(x, [-1])
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
# Runtime batch size of `x`: slice(shape(x), [0], [1]) takes the leading entry
# of the dynamic shape, reduce_sum collapses that 1-element tensor to a scalar.
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
# Batch-preserving flatten: reshape to (batchsize, -1). This later definition
# has an identical signature to the one above and therefore supersedes it.
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
# Input nodes are passthroughs in the generated TF graph.
graph(::Input, x) = x
@@ -62,8 +61,6 @@ type Model
grad::Tensor
end
# Render a compiled Model as a short placeholder string in Juno,
# mirroring the Tensor display workaround earlier in this module.
Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
function tf(model)
sess = Session(Graph())
input = placeholder(Float32)
@@ -108,4 +105,14 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
end
end
# A user-defined TensorFlow op: pairs a graph-building function with a
# shape-inference function. Fields are deliberately untyped so any callable
# can be stored.
type Op
f # called with input tensors to build the op's graph node(s); see graph(::Op, ...)
shape # (dims...) -> output shape; the Op(f) convenience default returns nothing (unknown)
end
# Convenience constructor: an Op whose output shape is unknown —
# the shape function ignores its arguments and reports `nothing`.
function Op(f)
    return Op(f, (dims...) -> nothing)
end

# Building the graph for a custom op simply delegates to its wrapped function.
function graph(op::Op, inputs...)
    return op.f(inputs...)
end

# Shape inference for a custom op delegates to its stored shape function.
function Flux.shape(op::Op, dims...)
    return op.shape(dims...)
end
end