bayesiancnn/uncertainty_estimation.py

185 lines
6.9 KiB
Python
Raw Normal View History

2024-05-10 09:59:24 +00:00
import argparse
import torch
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
import torchvision
from torch.nn import functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import data
from main_bayesian import getModel
import config_bayesian as cfg
# Device selection: use the first CUDA device when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Dataset handles; both stay None until init_dataset() populates them.
mnist_set = notmnist_set = None

# Preprocessing applied by get_sample(): convert to a PIL image, resize to
# 32x32, then convert back to a tensor.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)
def init_dataset(notmnist_dir):
    """Populate the module-level ``mnist_set`` and ``notmnist_set`` globals.

    ``mnist_set`` becomes the first element of the 4-tuple returned by
    ``data.getDataset('MNIST')``; ``notmnist_set`` becomes a torchvision
    ``ImageFolder`` rooted at ``notmnist_dir``.

    Args:
        notmnist_dir: Root directory handed to ``ImageFolder``.
    """
    global mnist_set, notmnist_set

    mnist_splits = data.getDataset('MNIST')
    # Only the first of the four returned values is kept.
    mnist_set, _, _, _ = mnist_splits
    notmnist_set = torchvision.datasets.ImageFolder(root=notmnist_dir)
def get_uncertainty_per_image(model, input_image, T=15, normalized=False):
    """Estimate per-category uncertainty for one image from T stacked copies.

    The image is repeated T times along a new leading axis and pushed through
    ``model`` in a single call. The rows of the output are treated as T
    samples of the predictive distribution.
    NOTE(review): this is only informative if the model yields different
    outputs per row (e.g. stochastic layers) -- confirm against the model.

    Args:
        model: Callable returning a pair whose first element is a 2-D tensor
            of raw outputs with one row per stacked copy; the second element
            is ignored.
        input_image: Image tensor; assumed (channels, height, width), since it
            is unsqueezed once and repeated with four repeat factors.
        T: Number of stacked copies (samples).
        normalized: If True, probabilities are softplus outputs divided by
            their row sum; otherwise a softmax over dim 1 is used.

    Returns:
        Tuple ``(pred, epistemic, aleatoric)`` of NumPy arrays, each with one
        entry per category: the mean raw output over the T rows, the diagonal
        of the sample covariance of the probabilities, and the diagonal of
        ``diag(p_bar) - E[p p^T]``.
    """
    stacked = input_image.unsqueeze(0).repeat(T, 1, 1, 1)
    logits, _ = model(stacked)
    pred = logits.mean(dim=0).detach().cpu().numpy()

    if normalized:
        positive = F.softplus(logits)
        probs = positive / positive.sum(dim=1, keepdim=True)
    else:
        probs = F.softmax(logits, dim=1)
    p_hat = probs.detach().cpu().numpy()

    # Mean probability vector over the T samples.
    p_bar = p_hat.mean(axis=0)

    # Epistemic part: per-category variance of the sampled probabilities.
    centered = p_hat - p_bar[np.newaxis, :]
    epistemic = np.diag(centered.T @ centered / T)

    # Aleatoric part: diagonal of diag(p_bar) minus the mean outer product.
    aleatoric = np.diag(np.diag(p_bar) - p_hat.T @ p_hat / T)

    return pred, epistemic, aleatoric
def _resolve_model_device(model, fallback):
    """Return the device of the model's first parameter, or ``fallback`` if it has none."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return fallback


def get_uncertainty_per_batch(model, batch, T=15, normalized=False):
    """Estimate per-sample, per-category uncertainty from T forward passes.

    The same batch is passed through ``model`` T times and the T outputs are
    treated as samples of the predictive distribution.
    NOTE(review): this is only informative if the model is stochastic between
    calls -- confirm against the model.

    The batch is moved to the device of the model's parameters (or left where
    it is when the model exposes none) instead of being forced onto CUDA, so
    the function also works on CPU-only hosts.

    Args:
        model: Callable returning a pair whose first element is a
            (batch_size, categories) tensor of raw outputs; the second
            element is ignored.
        batch: Input tensor whose first axis indexes the samples.
        T: Number of forward passes; must be at least 1.
        normalized: If True, probabilities are softplus outputs divided by
            their row sum; otherwise a softmax over dim 1 is used.

    Returns:
        Tuple ``(preds, epistemic, aleatoric)`` of NumPy arrays, each of shape
        (batch_size, categories): the mean raw output over the T passes, the
        per-category variance of the sampled probabilities, and
        ``p_bar - E[p^2]`` (the diagonal of ``diag(p_bar) - E[p p^T]``).

    Raises:
        ValueError: If ``T`` is smaller than 1.
    """
    if T < 1:
        raise ValueError("T must be a positive integer")

    inputs = batch.to(_resolve_model_device(model, batch.device))

    logits_per_pass = []
    probs_per_pass = []
    # Only detached NumPy results are returned, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(T):
            net_out, _ = model(inputs)
            logits_per_pass.append(net_out)
            if normalized:
                positive = F.softplus(net_out)
                probs = positive / positive.sum(dim=1, keepdim=True)
            else:
                probs = F.softmax(net_out, dim=1)
            probs_per_pass.append(probs)

    # Both stacks are (T, batch_size, categories).
    preds = torch.stack(logits_per_pass).mean(dim=0).detach().cpu().numpy()
    p_hat = torch.stack(probs_per_pass).detach().cpu().numpy()

    p_bar = p_hat.mean(axis=0)
    # Equivalent to diag((p_hat - p_bar)^T (p_hat - p_bar)) / T for each sample.
    epistemic = np.mean((p_hat - p_bar) ** 2, axis=0)
    # Equivalent to diag(diag(p_bar) - p_hat^T p_hat / T) for each sample.
    aleatoric = p_bar - np.mean(p_hat ** 2, axis=0)

    return preds, epistemic, aleatoric
def get_sample(dataset, sample_type='mnist'):
    """Draw one random sample from ``dataset`` and preprocess it.

    Args:
        dataset: For ``sample_type == 'mnist'`` an object exposing indexable
            ``data`` and ``targets``; otherwise an object exposing ``targets``
            and a ``samples`` sequence of ``(path, label)`` pairs.
        sample_type: ``'mnist'`` reads the image from ``dataset.data``; any
            other value loads it from the file path in ``dataset.samples``.

    Returns:
        Tuple ``(sample, truth)``: the image after the module-level
        ``transform``, moved to the module-level ``device``, and its label as
        stored in the dataset.
    """
    idx = np.random.randint(len(dataset.targets))

    if sample_type == 'mnist':
        sample = dataset.data[idx]
        truth = dataset.targets[idx]
    else:
        path, truth = dataset.samples[idx]
        # Image.open is lazy and keeps the file open; the context manager
        # closes it as soon as the pixels have been copied into the array.
        with Image.open(path) as img:
            sample = torch.from_numpy(np.array(img))

    # Add a leading axis before the transform.
    # NOTE(review): assumes a 2-D single-channel image, making this the channel axis -- confirm.
    sample = sample.unsqueeze(0)
    sample = transform(sample)
    return sample.to(device), truth
def run(net_type, weight_path, notmnist_dir):
    """Compare uncertainty estimates on one MNIST and one notMNIST sample.

    Loads the datasets and a model built by ``getModel`` with the given
    weights, draws one random image from each dataset, computes epistemic and
    aleatoric uncertainty with both the softplus-normalized and the softmax
    probabilities, prints the resulting table and shows a 3x2 figure with the
    two images and four bar plots.

    Args:
        net_type: Model name forwarded to ``getModel``.
        weight_path: Path of the state-dict checkpoint to load.
        notmnist_dir: Root directory of the notMNIST image folder.
    """
    init_dataset(notmnist_dir)

    layer_type = cfg.layer_type
    activation_type = cfg.activation_type

    net = getModel(net_type, 1, 10, priors=None, layer_type=layer_type, activation_type=activation_type)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
    # matching the CPU fallback used when `device` is chosen.
    net.load_state_dict(torch.load(weight_path, map_location=device))
    # NOTE(review): train mode is kept from the original; presumably so that
    # stochastic layers keep sampling during inference -- confirm.
    net.train()
    net.to(device)

    fig = plt.figure()
    fig.suptitle('Uncertainty Estimation', fontsize='x-large')
    mnist_img = fig.add_subplot(321)
    notmnist_img = fig.add_subplot(322)
    epi_stats_norm = fig.add_subplot(323)
    ale_stats_norm = fig.add_subplot(324)
    epi_stats_soft = fig.add_subplot(325)
    ale_stats_soft = fig.add_subplot(326)

    sample_mnist, truth_mnist = get_sample(mnist_set)
    pred_mnist, epi_mnist_norm, ale_mnist_norm = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=True)
    pred_mnist, epi_mnist_soft, ale_mnist_soft = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=False)
    mnist_img.imshow(sample_mnist.squeeze().cpu(), cmap='gray')
    mnist_img.axis('off')
    mnist_img.set_title('MNIST Truth: {} Prediction: {}'.format(int(truth_mnist), int(np.argmax(pred_mnist))))

    sample_notmnist, truth_notmnist = get_sample(notmnist_set, sample_type='notmnist')
    pred_notmnist, epi_notmnist_norm, ale_notmnist_norm = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=True)
    pred_notmnist, epi_notmnist_soft, ale_notmnist_soft = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=False)
    notmnist_img.imshow(sample_notmnist.squeeze().cpu(), cmap='gray')
    notmnist_img.axis('off')
    # Class indices are also shown as letters: chr(65 + i) maps 0 -> 'A', 1 -> 'B', ...
    notmnist_img.set_title('notMNIST Truth: {}({}) Prediction: {}({})'.format(
        int(truth_notmnist), chr(65 + truth_notmnist), int(np.argmax(pred_notmnist)), chr(65 + np.argmax(pred_notmnist))))

    x = list(range(10))
    # Named `stats` rather than `data` so the imported `data` module is not shadowed.
    stats = pd.DataFrame({
        'epistemic_norm': np.hstack([epi_mnist_norm, epi_notmnist_norm]),
        'aleatoric_norm': np.hstack([ale_mnist_norm, ale_notmnist_norm]),
        'epistemic_soft': np.hstack([epi_mnist_soft, epi_notmnist_soft]),
        'aleatoric_soft': np.hstack([ale_mnist_soft, ale_notmnist_soft]),
        'category': np.hstack([x, x]),
        'dataset': np.hstack([['MNIST']*10, ['notMNIST']*10])
    })
    print(stats)

    sns.barplot(x='category', y='epistemic_norm', hue='dataset', data=stats, ax=epi_stats_norm)
    sns.barplot(x='category', y='aleatoric_norm', hue='dataset', data=stats, ax=ale_stats_norm)
    epi_stats_norm.set_title('Epistemic Uncertainty (Normalized)')
    ale_stats_norm.set_title('Aleatoric Uncertainty (Normalized)')

    sns.barplot(x='category', y='epistemic_soft', hue='dataset', data=stats, ax=epi_stats_soft)
    sns.barplot(x='category', y='aleatoric_soft', hue='dataset', data=stats, ax=ale_stats_soft)
    epi_stats_soft.set_title('Epistemic Uncertainty (Softmax)')
    ale_stats_soft.set_title('Aleatoric Uncertainty (Softmax)')

    plt.show()
# Command-line entry point: parse the model/weights/dataset options and run the comparison.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="PyTorch Uncertainty Estimation b/w MNIST and notMNIST")
    parser.add_argument('--net_type', default='lenet', type=str, help='model')
    parser.add_argument('--weights_path', default='checkpoints/MNIST/bayesian/model_lenet.pt', type=str, help='weights for model')
    # The help text previously repeated "weights for model" from the option above.
    parser.add_argument('--notmnist_dir', default='data/notMNIST_small/', type=str, help='directory containing the notMNIST images')
    args = parser.parse_args()

    run(args.net_type, args.weights_path, args.notmnist_dir)