63 lines
1.8 KiB
Julia
Executable File
63 lines
1.8 KiB
Julia
Executable File
#using Turing, MLDatasets
using LinearAlgebra: I
|
||
|
||
"""
    get_data(name::AbstractString)

Load the train and test splits of a dataset through MLDatasets, with features stored
as `Float32`.

# Arguments
- `name`: `"mnist"` selects `MLDatasets.MNIST`; `"cifar"` selects `MLDatasets.CIFAR10`.

# Returns
A tuple `(train, test)` of the two split objects.

# Throws
- `ArgumentError` if `name` is not one of the supported dataset names.
"""
function get_data(name::AbstractString)
    if name == "mnist"
        train_data_mnist = MLDatasets.MNIST(; Tx=Float32, split=:train)
        test_data_mnist = MLDatasets.MNIST(; Tx=Float32, split=:test)
        return train_data_mnist, test_data_mnist
    elseif name == "cifar"
        train_data_cifar = MLDatasets.CIFAR10(; Tx=Float32, split=:train)
        test_data_cifar = MLDatasets.CIFAR10(; Tx=Float32, split=:test)
        return train_data_cifar, test_data_cifar
    end
    # Fail at the call site rather than returning `nothing`, which would only surface
    # later (e.g. as a MethodError when the result is destructured).
    throw(ArgumentError(
        "That is not a valid dataset: $(repr(name)); expected \"mnist\" or \"cifar\""))
end
|
||
|
||
"""
    loader(data::MLDatasets.MNIST=train_data; batchsize::Int=64, shuffle::Bool=true)

Wrap an MNIST split in a `Flux.DataLoader` over `(features, one-hot targets)`.

The features are reshaped to 28×28×1×N (a trivial channel dimension is inserted) and
the targets are one-hot encoded over the labels `0:9`, giving a 10×N matrix.

# Arguments
- `data`: an MNIST split. NOTE(review): the default refers to a global `train_data`
  that this file never defines; `loader()` with no argument only works if the caller
  has created it beforehand (e.g. `train_data, test_data = get_data("mnist")`).

# Keywords
- `batchsize`: number of samples per minibatch (default `64`).
- `shuffle`: forwarded to `Flux.DataLoader` (default `true`, the previous behaviour).
"""
function loader(data::MLDatasets.MNIST=train_data; batchsize::Int=64, shuffle::Bool=true)
    x4dim = reshape(data.features, 28, 28, 1, :)  # insert trivial channel dim
    yhot = Flux.onehotbatch(data.targets, 0:9)    # 10×N OneHotMatrix
    return Flux.DataLoader((x4dim, yhot); batchsize, shuffle)  # |> gpu
end
|
||
|
||
# Regularization strength `alpha` and the standard deviation `sig` of the zero-mean
# Gaussian prior used for the network parameters in `bayes_nn` (variance sig^2 = 1/alpha).
# Declared `const` because the model body reads `sig`; a non-const global would be an
# untyped lookup on every model evaluation.
const alpha = 0.09
const sig = sqrt(1.0 / alpha)
|
||
|
||
# Specify the probabilistic model.
#
# Bayesian neural network: a Gaussian prior over the flattened parameter vector, and
# each network output observed as the probabilities of a one-draw Multinomial.
#
# Arguments
# - `xs`: iterable of network inputs; the network is evaluated once per element.
# - `ys`: observed one-hot targets, indexable as `ys[i]`, aligned with `xs` so that the
#   i-th input produces the predictions matched against `ys[i]`.
# - `nparameters`: length of the flattened parameter vector.
# - `reconstruct`: maps a parameter vector to a callable network (presumably the
#   restructure closure returned by `Flux.destructure` — confirm with the caller).
@model function bayes_nn(xs, ys, nparameters, reconstruct)
    # Zero-mean isotropic Gaussian prior on all weights and biases with standard
    # deviation `sig`. The old `MvNormal(μ, σ_vector)` standard-deviation form is
    # deprecated in Distributions.jl; the `UniformScaling` form gives the covariance.
    parameters ~ MvNormal(zeros(nparameters), sig^2 * I)

    # Construct the network from the sampled parameters.
    nn = reconstruct(parameters)

    # Forward pass; a comprehension yields a concretely typed vector instead of `[]`.
    preds = [nn(x) for x in xs]
    length(preds) == length(ys) ||
        throw(DimensionMismatch("got $(length(preds)) inputs but $(length(ys)) targets"))

    # Observe each prediction column by column: a matrix output (classes × batch) gives
    # one Multinomial(1, p) observation per column, and a vector output is handled as a
    # single column. Each column must be a probability vector, which assumes the network
    # ends in a softmax — TODO confirm.
    for i in eachindex(preds)
        p = preds[i]
        for j in axes(p, 2)
            ys[i][:, j] ~ Multinomial(1, p[:, j])
        end
    end
end;
|
||
|
||
# A helper that builds the network from weights `theta` and runs it on data `x`.
#
# `re` maps a parameter vector to a callable network. It defaults to the global
# `reconstruct`, which this file does not define — NOTE(review): presumably the caller
# creates it via `Flux.destructure`; confirm. Passing `re` explicitly removes the
# dependency on that global.
nn_forward(x, theta, re=reconstruct) = re(theta)(x)
|
||
|
||
# Return the average of the first network output across multiple weight samples.
#
# Uses rows `1, 1 + step, 1 + 2step, ...` of `theta` (never beyond row `num`), builds
# the network for each row via `nn_forward`, and averages element `[1]` of its output.
#
# Arguments
# - `x`: input passed to the network.
# - `theta`: matrix with one parameter vector per row (indexed as `theta[i, :]`);
#   presumably posterior draws — confirm with the caller.
# - `num`: last candidate row; must be ≥ 1 and no larger than `size(theta, 1)`.
# Keywords
# - `step`: thinning interval between the rows used (default `10`).
function nn_predict(x, theta, num; step::Integer=10)
    num >= 1 || throw(ArgumentError("num must be at least 1, got $num"))
    step >= 1 || throw(ArgumentError("step must be at least 1, got $step"))
    # A generator avoids materialising the intermediate vector of predictions.
    return mean(nn_forward(x, theta[i, :])[1] for i in 1:step:num)
end;
|
||
|