# SparcityVSEfficiency/FNN/FNN.jl
#
# 223 lines
# 7.1 KiB
# Julia
# Raw Normal View History
#
# 2023-05-24 19:20:51 +00:00
# Classification of MNIST dataset using a convolutional network,
# which is a variant of the original LeNet from 1998.
# This example uses a GPU if you have one.
# And demonstrates how to save model state.
using MLDatasets, Flux, JLD2, CUDA # this will install everything if necc.
# Checkpoints are written into this sub-directory; create it if it does not exist yet.
folder = "runs"
if !isdir(folder)
    mkdir(folder)
end
filename = joinpath(folder, "lenet.jld2")
#===== DATA =====#
# Calling MLDatasets.MNIST() will download the dataset if necessary,
# and return a struct containing it.
# It takes a few seconds to read from disk each time, so do this once:
# Load both splits once, up front, and keep them as globals for the rest of the script.
train_data = MLDatasets.MNIST(split=:train)
test_data = MLDatasets.MNIST(split=:test)
# train_data.features is a 28×28×60000 Array{Float32, 3} of the images.
# Flux needs a 4D array, with the 3rd dim for channels -- here trivial, grayscale.
# Combine the reshape needed with other pre-processing:
"""
    loader(data::MNIST=train_data; batchsize::Int=64)

Wrap an MNIST split in a shuffled `Flux.DataLoader` and pass that loader through `gpu`.

The images are reshaped to 28×28×1×N, inserting a trivial channel dimension, and the
targets are one-hot encoded over the digits `0:9`.

# Arguments
- `data`: the dataset split to iterate over; defaults to the global `train_data`.

# Keywords
- `batchsize`: number of observations per mini-batch.
"""
function loader(data::MNIST=train_data; batchsize::Int=64)
    images = reshape(data.features, 28, 28, 1, :)  # insert trivial channel dim
    labels = Flux.onehotbatch(data.targets, 0:9)   # one column per image, one row per digit
    batches = Flux.DataLoader((images, labels); batchsize=batchsize, shuffle=true)
    return gpu(batches)
end
# loader() returns a DataLoader; its first element is a tuple holding one mini-batch:
loader()
# (28×28×1×64 Array{Float32, 4}, 10×64 OneHotMatrix(::Vector{UInt32}))
(x1, y1) = first(loader());
# If you are using a GPU, these should be CuArray{Float32, 4} etc.
# If not, the `gpu` function does nothing (except complain the first time).
#===== MODEL =====#
# LeNet has two convolutional layers, and our modern version has relu nonlinearities.
# After each conv layer there's a pooling step. Finally, there are some fully connected layers:
# Sizes in the comments are for a 28×28×1×N input batch (see the ARRAY SIZES section).
lenet = gpu(
    Chain(
        Conv((5, 5), 1 => 6, relu),    # -> 24×24×6×N
        MaxPool((2, 2)),               # -> 12×12×6×N
        Conv((5, 5), 6 => 16, relu),   # -> 8×8×16×N
        MaxPool((2, 2)),               # -> 4×4×16×N
        Flux.flatten,                  # -> 256×N
        Dense(256 => 120, relu),
        Dense(120 => 84, relu),
        Dense(84 => 10),               # one raw score (logit) per digit; no softmax here
    ),
)
# Notice that most of the parameters are in the final Dense layers.
# Forward pass of the still-untrained model on the sample batch `x1`:
y1hat = lenet(x1) # try it out
# Sum the softmax output over dim 1 (the 10 classes), giving one total per image:
sum(softmax(y1hat); dims=1)
# Each column of softmax(y1hat) may be thought of as the network's probabilities
# that an input image is in each of 10 classes. To find its most likely answer,
# we can look for the largest output in each column, without needing softmax first.
# At the moment, these don't resemble the true values at all.
# Left column: predicted digits; right column: true digits (one row per image).
@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))
#===== METRICS =====#
# We're going to log accuracy and loss during training. There's no advantage to
# calculating these on minibatches, since MNIST is small enough to do it at once.
using Statistics: mean # standard library
"""
    loss_and_accuracy(model, data::MNIST=test_data)

Evaluate `model` on the whole of `data` in a single batch.

# Arguments
- `model`: a callable mapping an image batch to raw class scores (logits, no softmax).
- `data`: the dataset split to evaluate on; defaults to the global `test_data`.

# Returns
A NamedTuple `(; loss, acc, split)`: the logit cross-entropy loss, the accuracy in
percent rounded to two decimals, and `data.split`.
"""
function loss_and_accuracy(model, data::MNIST=test_data)
    (x, y) = only(loader(data; batchsize=length(data)))  # make one big batch
    yhat = model(x)  # raw scores, one column per image
    loss = Flux.logitcrossentropy(yhat, y)  # did not include softmax in the model
    acc = round(100 * mean(Flux.onecold(yhat) .== Flux.onecold(y)); digits=2)
    return (; loss, acc, split=data.split)  # return a NamedTuple
end
# Baseline before training; with no `data` argument this evaluates the default `test_data`.
@show loss_and_accuracy(lenet); # accuracy about 10%, before training
#===== TRAINING =====#
# Let's collect some hyper-parameters in a NamedTuple, just to write them in one place.
# Global variables are fine -- we won't access this from inside any fast loops.
# All hyper-parameters live in one NamedTuple so they are written in exactly one place.
settings = (
    eta=3e-4,       # learning rate
    lambda=1e-2,    # for weight decay
    batchsize=128,
    epochs=10,
)

# One entry is pushed here for every epoch that gets logged.
train_log = Any[]

# Initialise the storage needed for the optimiser:
opt_rule = OptimiserChain(
    WeightDecay(settings.lambda),
    Adam(settings.eta),
)
opt_state = Flux.setup(opt_rule, lenet);
# Train for `settings.epochs` passes over the training set, logging metrics on odd epochs
# and saving the model state to `filename` on every 5th epoch.
for epoch in 1:settings.epochs
    # @time will show a much longer time for the first epoch, due to compilation
    @time for (x, y) in loader(batchsize=settings.batchsize)
        # Gradient of the mini-batch loss with respect to the model, then one optimiser step.
        grads = Flux.gradient(m -> Flux.logitcrossentropy(m(x), y), lenet)
        Flux.update!(opt_state, lenet, grads[1])
    end

    # Logging & saving, but not on every epoch
    if epoch % 2 == 1
        # `loss_and_accuracy` defaults to `test_data`, so the training split must be
        # named explicitly; otherwise `loss`/`acc` would just repeat the test metrics.
        loss, acc, _ = loss_and_accuracy(lenet, train_data)
        test_loss, test_acc, _ = loss_and_accuracy(lenet, test_data)
        @info "logging:" epoch acc test_acc
        nt = (; epoch, loss, acc, test_loss, test_acc)  # make a NamedTuple
        push!(train_log, nt)
    end
    if epoch % 5 == 0
        # Move the state to the CPU before writing it to disk.
        JLD2.jldsave(filename; lenet_state = Flux.state(lenet) |> cpu)
        println("saved to ", filename, " after ", epoch, " epochs")
    end
end
# Print everything that was recorded during training (one NamedTuple per logged epoch):
@show train_log;
# We can re-run the quick sanity-check of predictions:
y1hat = lenet(x1)
# Left column: predicted digits; right column: true digits (one row per image).
@show hcat(Flux.onecold(y1hat, 0:9), Flux.onecold(y1, 0:9))
#===== INSPECTION =====#
using ImageCore, ImageInTerminal
# Fetch the whole test split as one batch.
(xtest, ytest) = only(loader(test_data; batchsize=length(test_data)));

# There are many ways to look at images, you won't need ImageInTerminal if working in a notebook.
# ImageCore.Gray is a special type, which interprets numbers between 0.0 and 1.0 as shades:
cpu(transpose(Gray.(xtest[:, :, 1, 5])))
Flux.onecold(ytest, 0:9)[5]  # true label, should match!

# Let's look for the image whose classification is least certain.
# First, in each column of probabilities, ask for the largest one.
# Then, over all images, ask for the lowest such probability, and its index.
ptest = softmax(lenet(xtest))
max_p = maximum(ptest; dims=1)
_, i = findmin(vec(max_p))

cpu(transpose(Gray.(xtest[:, :, 1, i])))
Flux.onecold(ytest, 0:9)[i]      # true classification
ptest[:, i]                      # probabilities of all outcomes
Flux.onecold(ptest[:, i], 0:9)   # uncertain prediction
#===== ARRAY SIZES =====#
# A layer like Conv((5, 5), 1=>6) takes 5x5 patches of an image, and matches them to each
# of 6 different 5x5 filters, placed at every possible position. These filters are here:
# Describe the filter array of a freshly constructed layer:
summary(Conv((5, 5), 1 => 6).weight)  # 5×5×1×6 Array{Float32, 4}
# This layer can accept any size of image; let's trace the sizes with the actual input:
#=
julia> x1 |> size
(28, 28, 1, 64)
julia> lenet[1](x1) |> size # after Conv((5, 5), 1=>6, relu),
(24, 24, 6, 64)
julia> lenet[1:2](x1) |> size # after MaxPool((2, 2))
(12, 12, 6, 64)
julia> lenet[1:3](x1) |> size # after Conv((5, 5), 6 => 16, relu)
(8, 8, 16, 64)
julia> lenet[1:4](x1) |> size # after MaxPool((2, 2))
(4, 4, 16, 64)
julia> lenet[1:5](x1) |> size # after Flux.flatten
(256, 64)
=#
# Flux.flatten is just reshape, preserving the batch dimension (64) while combining others (4*4*16).
# This 256 must match the Dense(256 => 120). Here is how to automate this, with Flux.outputsize:
# Same architecture as `lenet`, but every `_` is filled in by the macro from the
# given input size (28, 28, 1, 1), so no layer width has to be worked out by hand.
lenet2 = Flux.@autosize (28, 28, 1, 1) Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), _ => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(_ => 120, relu),
    Dense(_ => 84, relu),
    Dense(_ => 10),
)

# Check that this indeed accepts input the same size as above
# (`lenet2` was not passed through `gpu`, so the batch is moved to the CPU first):
@show lenet2(cpu(x1)) |> size;
#===== LOADING =====#
# During training, the code above saves the model state to disk. Load the last version:
# Read back the state that the training loop stored under the key "lenet_state".
loaded_state = JLD2.load(filename, "lenet_state");

# Now you would normally re-create the model, and copy all parameters into that.
# We can use lenet2 from just above:
Flux.loadmodel!(lenet2, loaded_state)

# Check that it now agrees with the earlier, trained, model.
# The comparison is approximate and done on the CPU, because `lenet` may live on the GPU.
@show isapprox(lenet2(cpu(x1)), cpu(lenet(x1)));
#===== THE END =====#