updates
This commit is contained in:
parent
409c44d362
commit
4961bf72af
|
@ -1,15 +1,11 @@
|
|||
module TF
|
||||
|
||||
using ..Flux, Flow, TensorFlow
|
||||
import Juno: info
|
||||
using ..Flux, Flow, TensorFlow, Juno
|
||||
import Flux: accuracy
|
||||
import Juno: info
|
||||
|
||||
export tf
|
||||
|
||||
# Workaround for tensor display bug
|
||||
using Juno
|
||||
Media.render(::Juno.Clipboard, ::Tensor) = "Tensor()"
|
||||
|
||||
cvalue(x) = x
|
||||
cvalue(c::Constant) = c.value
|
||||
cvalue(v::Vertex) = cvalue(value(v))
|
||||
|
@ -36,7 +32,10 @@ graph(::typeof(+), args...) = +(args...)
|
|||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
graph(::typeof(tanh), x) = tanh(x)
|
||||
graph(::typeof(flatten), x) = reshape(x, [-1])
|
||||
|
||||
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
|
@ -62,8 +61,6 @@ type Model
|
|||
grad::Tensor
|
||||
end
|
||||
|
||||
Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
|
||||
|
||||
function tf(model)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float32)
|
||||
|
@ -108,4 +105,14 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||
end
|
||||
end
|
||||
|
||||
type Op
|
||||
f
|
||||
shape
|
||||
end
|
||||
|
||||
Op(f) = Op(f, (d...) -> nothing)
|
||||
|
||||
graph(op::Op, xs...) = op.f(xs...)
|
||||
Flux.shape(op::Op, d...) = op.shape(d...)
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue