#= network_shape = [
    ((5, 5), 1 => 6, relu),
    ((2, 2)),
    ((5, 5), 6 => 16, relu),
    ((2, 2)),
    Flux.flatten,
    (256 => 120, relu),
    (120 => 84, relu),
    (84 => 10),
]; =#

#########################################################
# Import libraries.
using Turing, Flux, Plots, Random, ReverseDiff, MLDatasets
include("./aux_func.jl")

# Hide sampling progress.
Turing.setprogress!(false);

# Use the ReverseDiff AD backend because of the large number of
# parameters in neural networks.
Turing.setadbackend(:reversediff)

train_mnist, test_mnist = get_data("mnist")
#train_cifar, test_cifar = get_data("cifar")

# Number of points to generate.
#N = 80;
#M = round(Int, N / 4);

Random.seed!(1234)

#=
# Generate artificial data.
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt1s = Array([[x1s[i] + 0.5; x2s[i] + 0.5] for i in 1:M])
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt1s, Array([[x1s[i] - 5; x2s[i] - 5] for i in 1:M]))

x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
xt0s = Array([[x1s[i] + 0.5; x2s[i] - 5] for i in 1:M])
x1s = rand(M) * 4.5;
x2s = rand(M) * 4.5;
append!(xt0s, Array([[x1s[i] - 5; x2s[i] + 0.5] for i in 1:M]))

# Store all the data for later.
xs = [xt1s; xt0s]
ts = [ones(2 * M); zeros(2 * M)]

# Plot data points.
function plot_data()
    x1 = map(e -> e[1], xt1s)
    y1 = map(e -> e[2], xt1s)
    x2 = map(e -> e[1], xt0s)
    y2 = map(e -> e[2], xt0s)
    Plots.scatter(x1, y1; color="red", clim=(0, 1))
    return Plots.scatter!(x2, y2; color="blue", clim=(0, 1))
end

plot_data()
=#

# Construct the LeNet-5 convolutional network using Flux.
lenet = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    Flux.flatten,
    Dense(256 => 120, relu),
    Dense(120 => 84, relu),
    Dense(84 => 10),
)

# Collect every mini-batch of images and one-hot labels up front.
batches = loader(train_mnist);
xs = [];
ys = [];
for b in batches
    push!(xs, b[1])
    push!(ys, b[2])
end
#x1, y1 = first(loader(train_mnist));
# (28×28×1×64 Array{Float32, 4}, 10×64 OneHotMatrix(::Vector{UInt32}))

# Extract the flat weight vector and a helper function that
# reconstructs the NN from such a vector.
parameters_initial, reconstruct = Flux.destructure(lenet);
tot_param = length(parameters_initial); # number of parameters in the NN

# Perform inference.
N = 5000;
ch = sample(bayes_nn(xs, ys, tot_param, reconstruct), HMC(0.05, 4), N);

# Extract all weight and bias parameters.
theta = MCMCChains.group(ch, :parameters).value;

#= Visualisation code from the original 2D tutorial. It depends on the
artificial data and `plot_data` commented out above (and on `nn_forward`
and `nn_predict` from aux_func.jl), so it cannot run against the 28×28
MNIST inputs; it is kept here for reference only.

# Plot the data we have.
plot_data()

# Find the index that provided the highest log posterior in the chain.
_, i = findmax(ch[:lp])

# Extract the row index from the CartesianIndex.
i = i.I[1]

# Plot the posterior distribution with a contour plot.
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)

# Plot the average prediction.
plot_data()

n_end = 1500
x1_range = collect(range(-6; stop=6, length=25))
x2_range = collect(range(-6; stop=6, length=25))
Z = [nn_predict([x1, x2], theta, n_end)[1] for x1 in x1_range, x2 in x2_range]
contour!(x1_range, x2_range, Z)

# Number of iterations to plot.
n_end = 500

anim = @gif for i in 1:n_end
    plot_data()
    Z = [nn_forward([x1, x2], theta[i, :])[1] for x1 in x1_range, x2 in x2_range]
    contour!(x1_range, x2_range, Z; title="Iteration $i", clim=(0, 1))
end every 5
=#
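
#########################################################
# The helpers `get_data` and `loader` are defined in aux_func.jl, which is
# not shown. The sketch below is an assumption of their behaviour, built on
# MLDatasets and Flux's DataLoader: images are reshaped to the WHCN layout
# that Conv layers expect and labels are one-hot encoded. Kept in a block
# comment so it does not shadow the real definitions.
#=
function get_data(name)
    if name == "mnist"
        return MLDatasets.MNIST(; split=:train), MLDatasets.MNIST(; split=:test)
    end
end

function loader(data; batchsize=64)
    x = reshape(Float32.(data.features), 28, 28, 1, :) # WHCN layout for Conv
    y = Flux.onehotbatch(data.targets, 0:9)            # 10×N one-hot labels
    return Flux.DataLoader((x, y); batchsize, shuffle=true)
end
=#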
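
# `bayes_nn` also lives in aux_func.jl. A minimal sketch of such a model,
# assuming an isotropic Gaussian prior over the flattened weight vector and
# a categorical likelihood over the ten digit classes; the prior scale
# `alpha` is an assumed hyperparameter, not necessarily the one used in
# aux_func.jl.
#=
using LinearAlgebra

@model function bayes_nn(xs, ys, nparameters, reconstruct; alpha=0.09)
    # Prior: every weight and bias is drawn from N(0, 1/alpha).
    parameters ~ MvNormal(zeros(nparameters), I / alpha)
    # Rebuild the LeNet from the sampled weight vector.
    nn = reconstruct(parameters)
    for (x, y) in zip(xs, ys)
        # Categorical log-likelihood of each one-hot batch, added
        # directly to the joint log density.
        Turing.@addlogprob! -Flux.logitcrossentropy(nn(x), y; agg=sum)
    end
end
=#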
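
# `nn_forward` and `nn_predict`, used by the commented-out visualisation
# above, come from aux_func.jl as well. A plausible sketch, assuming
# `nn_forward` evaluates the network rebuilt from one posterior sample and
# `nn_predict` averages the output over the first `n` samples, thinned by 10.
#=
using Statistics

nn_forward(x, theta) = reconstruct(theta)(x)

nn_predict(x, theta, n) = mean(nn_forward(x, theta[i, :]) for i in 1:10:n)
=#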
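
# For the MNIST setup itself, a natural use of the chain is posterior-mean
# prediction on held-out digits. This usage example is an assumption built
# on the hypothetical `nn_predict` sketched above; it is not part of the
# original script.
#=
x_test, y_test = first(loader(test_mnist))
probs = softmax(nn_predict(x_test, theta, 1000)) # average over posterior samples
acc = mean(Flux.onecold(probs, 0:9) .== Flux.onecold(y_test, 0:9))
println("Posterior-mean accuracy on one test batch: $acc")
=#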