Float32 by default
This commit is contained in:
parent
c646ba4483
commit
bf04b70ad1
@ -1,6 +1,6 @@
|
|||||||
using Flux, MNIST
|
using Flux, MNIST
|
||||||
|
|
||||||
data = [(trainfeatures(i), Vector{Float64}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
data = [(trainfeatures(i), Vector{Float32}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
||||||
train = data[1:50_000]
|
train = data[1:50_000]
|
||||||
test = data[50_001:60_000]
|
test = data[50_001:60_000]
|
||||||
|
|
||||||
|
@ -40,8 +40,8 @@ using TensorFlow
|
|||||||
|
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
|
|
||||||
x = placeholder(Float64)
|
x = placeholder(Float32)
|
||||||
y′ = placeholder(Float64)
|
y′ = placeholder(Float32)
|
||||||
|
|
||||||
y = Tensor(lenet, x)
|
y = Tensor(lenet, x)
|
||||||
|
|
||||||
|
@ -65,7 +65,7 @@ Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
|
|||||||
|
|
||||||
function tf(model)
|
function tf(model)
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
input = placeholder(Float64)
|
input = placeholder(Float32)
|
||||||
g = graph(model, input)
|
g = graph(model, input)
|
||||||
run(sess, initialize_all_variables())
|
run(sess, initialize_all_variables())
|
||||||
Model(sess, [input], g, gradients(g, input))
|
Model(sess, [input], g, gradients(g, input))
|
||||||
@ -90,7 +90,7 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
||||||
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
||||||
i = 0
|
i = 0
|
||||||
Y = placeholder(Float64)
|
Y = placeholder(Float32)
|
||||||
Loss = loss(m.graph, Y)
|
Loss = loss(m.graph, Y)
|
||||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||||
for e in 1:epoch
|
for e in 1:epoch
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
export Conv2D, MaxPool
|
export Conv2D, MaxPool
|
||||||
|
|
||||||
type Conv2D <: Model
|
type Conv2D <: Model
|
||||||
filter::Param{Array{Float64,4}} # [height, width, outchans, inchans]
|
filter::Param{Array{Float32,4}} # [height, width, outchans, inchans]
|
||||||
stride::Dims{2}
|
stride::Dims{2}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ const AArray = AbstractArray
|
|||||||
onehot(label, labels) = [i == label for i in labels]
|
onehot(label, labels) = [i == label for i in labels]
|
||||||
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/1000
|
initn(dims...) = randn(Float32, dims...)/1000
|
||||||
|
|
||||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||||
i = 0
|
i = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user