Float32 by default
This commit is contained in:
parent
c646ba4483
commit
bf04b70ad1
|
@ -1,6 +1,6 @@
|
|||
using Flux, MNIST
|
||||
|
||||
data = [(trainfeatures(i), Vector{Float64}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
||||
data = [(trainfeatures(i), Vector{Float32}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
test = data[50_001:60_000]
|
||||
|
||||
|
|
|
@ -40,8 +40,8 @@ using TensorFlow
|
|||
|
||||
sess = Session(Graph())
|
||||
|
||||
x = placeholder(Float64)
|
||||
y′ = placeholder(Float64)
|
||||
x = placeholder(Float32)
|
||||
y′ = placeholder(Float32)
|
||||
|
||||
y = Tensor(lenet, x)
|
||||
|
||||
|
|
|
@ -65,7 +65,7 @@ Media.render(::Juno.Clipboard, ::Model) = "Flux.TF.Model()"
|
|||
|
||||
function tf(model)
|
||||
sess = Session(Graph())
|
||||
input = placeholder(Float64)
|
||||
input = placeholder(Float32)
|
||||
g = graph(model, input)
|
||||
run(sess, initialize_all_variables())
|
||||
Model(sess, [input], g, gradients(g, input))
|
||||
|
@ -90,7 +90,7 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||
loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
|
||||
opt = TensorFlow.train.GradientDescentOptimizer(η))
|
||||
i = 0
|
||||
Y = placeholder(Float64)
|
||||
Y = placeholder(Float32)
|
||||
Loss = loss(m.graph, Y)
|
||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||
for e in 1:epoch
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
export Conv2D, MaxPool
|
||||
|
||||
type Conv2D <: Model
|
||||
filter::Param{Array{Float64,4}} # [height, width, outchans, inchans]
|
||||
filter::Param{Array{Float32,4}} # [height, width, outchans, inchans]
|
||||
stride::Dims{2}
|
||||
end
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ const AArray = AbstractArray
|
|||
onehot(label, labels) = [i == label for i in labels]
|
||||
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
||||
|
||||
initn(dims...) = randn(dims...)/1000
|
||||
initn(dims...) = randn(Float32, dims...)/1000
|
||||
|
||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||
i = 0
|
||||
|
|
Loading…
Reference in New Issue