# SparcityVSEfficiency/BNN/BNN.jl

#=
network_shape = [
    ((5, 5), 1 => 6, relu),
    ((2, 2)),
    ((5, 5), 6 => 16, relu),
    ((2, 2)),
    Flux.flatten,
    (256 => 120, relu),
    (120 => 84, relu),
    (84 => 10),
];
=#
#=
lenet = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
)
=#
#########################################################
# Import libraries.
using Turing, Flux, Plots, Random, ReverseDiff, MLDatasets
include("./aux_func.jl")
# Hide sampling progress.
Turing.setprogress!(false);
# Use the ReverseDiff backend because of the large number of parameters in the network.
Turing.setadbackend(:reversediff)
train_mnist, test_mnist = get_data("mnist")
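#=
Hedged sketch (an assumption, not the actual definitions): `get_data` and
`loader` live in aux_func.jl, whose contents are not shown here. Assuming they
wrap MLDatasets and Flux.DataLoader for MNIST, they might look roughly like the
following; names and keyword defaults are illustrative only.

function get_data(name)
    name == "mnist" &&
        return MLDatasets.MNIST(split=:train), MLDatasets.MNIST(split=:test)
end

function loader(data; batchsize=64)
    x = reshape(data.features, 28, 28, 1, :)   # add a channel dimension
    y = Flux.onehotbatch(data.targets, 0:9)    # 10×N one-hot labels
    return Flux.DataLoader((x, y); batchsize, shuffle=true)
end
=#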
#train_cifar, test_cifar = get_data("cifar")
# Number of points to generate.
#N = 80;
#M = round(Int, N / 4);
Random.seed!(1234)
#=
# Generate artificial data.
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt1s = Array([[x1s[i] + 0.5; x2s[i] + 0.5] for i in 1:M])
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt1s, Array([[x1s[i] - 5; x2s[i] - 5] for i in 1:M]))
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt0s = Array([[x1s[i] + 0.5; x2s[i] - 5] for i in 1:M])
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt0s, Array([[x1s[i] - 5; x2s[i] + 0.5] for i in 1:M]))
# Store all the data for later.
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]
# Plot data points.
function plot_data()
    x1 = map(e -> e[1], xt1s)
    y1 = map(e -> e[2], xt1s)
    x2 = map(e -> e[1], xt0s)
    y2 = map(e -> e[2], xt0s)
    Plots.scatter(x1, y1; color="red", clim=(0, 1))
    return Plots.scatter!(x2, y2; color="blue", clim=(0, 1))
end
plot_data()
=#
# Construct a neural network using Flux
lenet = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
)
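# Illustrative sanity check (an addition, not part of the original script):
# for a single 28×28 grayscale MNIST image the LeNet above should return a
# 10×1 matrix of logits (spatial sizes 24→12→8→4, then 4*4*16 = 256 features).
@assert size(lenet(rand(Float32, 28, 28, 1, 1))) == (10, 1)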
batches = loader(train_mnist);
# Collect every mini-batch into separate vectors of inputs and one-hot labels.
xs = [];
ys = [];
for b in batches
    push!(xs, b[1])
    push!(ys, b[2])
end
#x1, y1 = first(loader(train_mnist)); # (28×28×1×64 Array{Float32, 4}, 10×64 OneHotMatrix(::Vector{UInt32}))
# Extract weights and a helper function to reconstruct NN from weights
parameters_initial, reconstruct = Flux.destructure(lenet);
tot_param = length(parameters_initial); # number of parameters in NN
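# `Flux.destructure` returns the flat parameter vector together with a closure
# that rebuilds the model from any vector of the same length; the round trip
# below is an illustrative check (the variable name here is ours, not the original's).
lenet_rebuilt = reconstruct(parameters_initial);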
# Perform inference with HMC (step size 0.05, 4 leapfrog steps per iteration).
N = 5000;
ch = sample(
    bayes_nn(xs, ys, tot_param, reconstruct), HMC(0.05, 4), N
);
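#=
Hedged sketch: `bayes_nn` is defined in aux_func.jl and its actual body is not
shown here. Assuming it follows the Turing BNN tutorial pattern, it would place
an isotropic Gaussian prior over the flattened parameters and score the one-hot
labels with a categorical (cross-entropy) likelihood, roughly:

using LinearAlgebra: I

@model function bayes_nn(xs, ys, nparameters, reconstruct; alpha=0.09)
    # Prior over every weight and bias of the network.
    parameters ~ MvNormal(zeros(nparameters), (1 / alpha) * I)
    # Rebuild the network from the sampled flat parameter vector.
    nn = reconstruct(parameters)
    for i in eachindex(xs)
        logits = nn(xs[i])
        # Log-likelihood of the observed one-hot labels for this batch.
        Turing.@addlogprob! -Flux.logitcrossentropy(logits, ys[i]; agg=sum)
    end
end
=#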
# Extract all weight and bias parameters.
theta = MCMCChains.group(ch, :parameters).value;
# NOTE: the plotting below follows the 2-D toy example that is commented out
# above; `plot_data` is defined in that block, and `nn_forward`/`nn_predict`
# are expected from aux_func.jl (see the sketch after the contour plots).
# Plot the data we have.
plot_data()
# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp]);
# Extract the row (iteration) index from the CartesianIndex.
i = i.I[1];
# Plot the posterior distribution with a contour plot
x1_range = collect(range(-6; stop=6, length=25));
x2_range = collect(range(-6; stop=6, length=25));
Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range];
contour!(x1_range, x2_range, Z)
# Plot the average prediction.
plot_data()
n_end = 1500;
x1_range = collect(range(-6; stop=6, length=25));
x2_range = collect(range(-6; stop=6, length=25));
Z = [nn_predict([x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range];
contour!(x1_range, x2_range, Z)
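#=
Hedged sketch: `nn_forward` and `nn_predict` are assumed to come from
aux_func.jl. In the Turing BNN tutorial they are defined roughly as below,
shown here only to explain the contour plots for the 2-D toy example
(`mean` needs the Statistics standard library):

nn_forward(x, theta) = reconstruct(theta)(x)

function nn_predict(x, theta, num)
    # Average the network output over the first `num` posterior samples.
    return mean([nn_forward(x, theta[i, :])[1] for i in 1:num])
end
=#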
# Number of iterations to plot.
n_end = 500;
anim = @gif for i in 1:n_end
    plot_data()
    Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range]
    contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1))
end every 5