updates
This commit is contained in:
parent
409c44d362
commit
4961bf72af
@ -1,15 +1,11 @@
|
|||||||
module TF
|
module TF
|
||||||
|
|
||||||
using ..Flux, Flow, TensorFlow
|
using ..Flux, Flow, TensorFlow, Juno
|
||||||
import Juno: info
|
|
||||||
import Flux: accuracy
|
import Flux: accuracy
|
||||||
|
import Juno: info
|
||||||
|
|
||||||
export tf
|
export tf
|
||||||
|
|
||||||
# Workaround for tensor display bug
|
|
||||||
using Juno
|
|
||||||
Media.render(::Juno.Clipboard, ::Tensor) = "Tensor()"
|
|
||||||
|
|
||||||
cvalue(x) = x
|
cvalue(x) = x
|
||||||
cvalue(c::Constant) = c.value
|
cvalue(c::Constant) = c.value
|
||||||
cvalue(v::Vertex) = cvalue(value(v))
|
cvalue(v::Vertex) = cvalue(value(v))
|
||||||
@ -36,7 +32,10 @@ graph(::typeof(+), args...) = +(args...)
|
|||||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||||
graph(::typeof(relu), x) = nn.relu(x)
|
graph(::typeof(relu), x) = nn.relu(x)
|
||||||
graph(::typeof(tanh), x) = tanh(x)
|
graph(::typeof(tanh), x) = tanh(x)
|
||||||
graph(::typeof(flatten), x) = reshape(x, [-1])
|
|
||||||
|
# reshape hack due to https://github.com/malmaud/TensorFlow.jl/issues/79
|
||||||
|
batchsize(x::Tensor) = reduce_sum(slice(TensorFlow.shape(x), [0], [1]))
|
||||||
|
graph(::typeof(flatten), x) = reshape(x, pack([batchsize(x),Int32(-1)]))
|
||||||
|
|
||||||
graph(::Input, x) = x
|
graph(::Input, x) = x
|
||||||
|
|
||||||
@ -62,8 +61,6 @@ type Model
|
|||||||
grad::Tensor
|
grad::Tensor
|
||||||
end
|
end
|
||||||
|
|
||||||
Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
|
|
||||||
|
|
||||||
function tf(model)
|
function tf(model)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float32)
|
input = placeholder(Float32)
|
||||||
@ -108,4 +105,14 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
type Op
|
||||||
|
f
|
||||||
|
shape
|
||||||
|
end
|
||||||
|
|
||||||
|
Op(f) = Op(f, (d...) -> nothing)
|
||||||
|
|
||||||
|
graph(op::Op, xs...) = op.f(xs...)
|
||||||
|
Flux.shape(op::Op, d...) = op.shape(d...)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user