# SparcityVSEfficiency/aux_func.jl

# Packages used by these helpers: Turing/Distributions for the probabilistic
# model, MLDatasets for the data, Flux for the network utilities,
# LinearAlgebra for `I`, and Statistics for `mean`.
using Turing, MLDatasets
using Flux, Distributions, LinearAlgebra, Statistics
# Function to get datasets
function get_data(name::String)
    if name == "mnist"
        train_data_mnist = MLDatasets.MNIST(; Tx=Float32, split=:train)
        test_data_mnist = MLDatasets.MNIST(; Tx=Float32, split=:test)
        return train_data_mnist, test_data_mnist
    elseif name == "cifar"
        train_data_cifar = MLDatasets.CIFAR10(; Tx=Float32, split=:train)
        test_data_cifar = MLDatasets.CIFAR10(; Tx=Float32, split=:test)
        return train_data_cifar, test_data_cifar
    else
        error("\"$name\" is not a valid dataset; use \"mnist\" or \"cifar\"")
    end
end
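# Usage sketch:
#   train_data, test_data = get_data("mnist")
#   size(train_data.features)   # (28, 28, 60000)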
# Wrap a dataset in a shuffled DataLoader. The default `train_data` is a
# global resolved at call time (e.g. the first value returned by
# `get_data("mnist")`).
function loader(data::MNIST=train_data; batchsize::Int=64)
    x4dim = reshape(data.features, 28, 28, 1, :)   # insert trivial channel dim
    yhot = Flux.onehotbatch(data.targets, 0:9)     # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x4dim, yhot); batchsize, shuffle=true) #|> gpu
end
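# Usage sketch (batch size 128 is arbitrary):
#   train_data, _ = get_data("mnist")
#   xb, yb = first(loader(train_data; batchsize=128))
#   size(xb), size(yb)          # ((28, 28, 1, 128), (10, 128))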
# Regularization strength `alpha`; the corresponding Gaussian prior standard
# deviation is sig = sqrt(1 / alpha).
alpha = 0.09;
sig = sqrt(1.0 / alpha);
# Specify the probabilistic model.
@model function bayes_nn(xs, ys, nparameters, reconstruct)
    # Prior over the flattened weight and bias vector
    # (isotropic Gaussian with variance sig^2 = 1 / alpha).
    parameters ~ MvNormal(zeros(nparameters), sig^2 * I)
    # Construct the NN from the sampled parameters.
    nn = reconstruct(parameters)
    # Forward pass: one class-probability vector per input
    # (the network is assumed to end in a softmax).
    preds = [nn(x) for x in xs]
    # Observe each one-hot label as a single draw from the predicted probabilities.
    for i in eachindex(ys)
        ys[i] ~ Multinomial(1, preds[i])
    end
end;
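# Illustrative driver (a sketch, not called automatically): builds a small Flux
# classifier, flattens it with Flux.destructure to obtain the `reconstruct`
# closure and parameter count that `bayes_nn` expects, and draws posterior
# samples with NUTS. The architecture, the helper name `example_bayes_nn_fit`,
# and the subset/sampler settings are illustrative assumptions.
function example_bayes_nn_fit(train_data; nobs::Int=100, nsamples::Int=500)
    nn_initial = Chain(Dense(784 => 32, tanh), Dense(32 => 10), softmax)
    parameters_initial, reconstruct = Flux.destructure(nn_initial)
    # Small subset of flattened images and one-hot (count-vector) labels.
    xs = [vec(train_data.features[:, :, i]) for i in 1:nobs]
    ys = [Int.(Flux.onehot(train_data.targets[i], 0:9)) for i in 1:nobs]
    model = bayes_nn(xs, ys, length(parameters_initial), reconstruct)
    return sample(model, NUTS(), nsamples), reconstruct
end
# For example:
#   chain, reconstruct = example_bayes_nn_fit(get_data("mnist")[1])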
# Helper to rebuild the NN from a flat weight vector `theta` and run it on `x`.
# Assumes a `reconstruct` closure is in scope (e.g. from Flux.destructure).
nn_forward(x, theta) = reconstruct(theta)(x)
# Return the predicted value averaged over posterior weight samples, thinning
# to every tenth row of `theta`; `[1]` keeps the first output component of
# each forward pass.
function nn_predict(x, theta, num)
    return mean([nn_forward(x, theta[i, :])[1] for i in 1:10:num])
end;
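# Usage sketch (assumes a single chain; `chain`, `test_data`, and the
# MCMCChains.group call are illustrative):
#   theta = MCMCChains.group(chain, :parameters).value
#   nn_predict(vec(test_data.features[:, :, 1]), theta, size(theta, 1))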