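"""Uncertainty estimation with a Bayesian CNN: compares per-class epistemic and
aleatoric uncertainty on an in-distribution MNIST sample against an
out-of-distribution notMNIST sample."""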
import argparse

import torch
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
import torchvision
from torch.nn import functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import data
from main_bayesian import getModel
import config_bayesian as cfg


# CUDA settings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mnist_set = None
notmnist_set = None

transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])


def init_dataset(notmnist_dir):
    global mnist_set
    global notmnist_set
    mnist_set, _, _, _ = data.getDataset('MNIST')
    notmnist_set = torchvision.datasets.ImageFolder(root=notmnist_dir)
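

# The two functions below decompose predictive uncertainty from T stochastic
# forward passes. With p_hat the (T, classes) matrix of per-pass probabilities
# and p_bar its mean over the T passes, they compute, per class (this appears
# to be the decomposition of Kwon et al.):
#   epistemic = diag((p_hat - p_bar)^T (p_hat - p_bar) / T)  # spread across weight samples
#   aleatoric = diag(diag(p_bar) - p_hat^T p_hat / T)        # expected data noise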
def get_uncertainty_per_image(model, input_image, T=15, normalized=False):
    input_image = input_image.unsqueeze(0)
    input_images = input_image.repeat(T, 1, 1, 1)

    net_out, _ = model(input_images)
    pred = torch.mean(net_out, dim=0).cpu().detach().numpy()
    if normalized:
        prediction = F.softplus(net_out)
        p_hat = prediction / torch.sum(prediction, dim=1).unsqueeze(1)
    else:
        p_hat = F.softmax(net_out, dim=1)
    p_hat = p_hat.detach().cpu().numpy()
    p_bar = np.mean(p_hat, axis=0)

    temp = p_hat - np.expand_dims(p_bar, 0)
    epistemic = np.dot(temp.T, temp) / T
    epistemic = np.diag(epistemic)

    aleatoric = np.diag(p_bar) - (np.dot(p_hat.T, p_hat) / T)
    aleatoric = np.diag(aleatoric)

    return pred, epistemic, aleatoric
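

# Batched variant of get_uncertainty_per_image: runs T stochastic forward
# passes over the whole batch, then computes the same per-class epistemic and
# aleatoric terms independently for each sample.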
def get_uncertainty_per_batch(model, batch, T=15, normalized=False):
    batch_predictions = []
    net_outs = []
    batches = batch.unsqueeze(0).repeat(T, 1, 1, 1, 1)

    preds = []
    epistemics = []
    aleatorics = []

    for i in range(T):  # for T batches
        # .to(device) rather than .cuda() so CPU-only runs also work
        net_out, _ = model(batches[i].to(device))
        net_outs.append(net_out)
        if normalized:
            prediction = F.softplus(net_out)
            prediction = prediction / torch.sum(prediction, dim=1).unsqueeze(1)
        else:
            prediction = F.softmax(net_out, dim=1)
        batch_predictions.append(prediction)

    for sample in range(batch.shape[0]):
        # for each sample in a batch
        pred = torch.cat([a_batch[sample].unsqueeze(0) for a_batch in net_outs], dim=0)
        pred = torch.mean(pred, dim=0)
        preds.append(pred)

        p_hat = torch.cat([a_batch[sample].unsqueeze(0) for a_batch in batch_predictions], dim=0).detach().cpu().numpy()
        p_bar = np.mean(p_hat, axis=0)

        temp = p_hat - np.expand_dims(p_bar, 0)
        epistemic = np.dot(temp.T, temp) / T
        epistemic = np.diag(epistemic)
        epistemics.append(epistemic)

        aleatoric = np.diag(p_bar) - (np.dot(p_hat.T, p_hat) / T)
        aleatoric = np.diag(aleatoric)
        aleatorics.append(aleatoric)

    epistemic = np.vstack(epistemics)  # (batch_size, categories)
    aleatoric = np.vstack(aleatorics)  # (batch_size, categories)
    preds = torch.cat([i.unsqueeze(0) for i in preds]).cpu().detach().numpy()  # (batch_size, categories)

    return preds, epistemic, aleatoric
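

# Note: torchvision's ImageFolder assigns integer labels 0-9 to the notMNIST
# class folders A-J in alphabetical order, which is why run() renders a
# notMNIST label as chr(65 + label).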
def get_sample(dataset, sample_type='mnist'):
    idx = np.random.randint(len(dataset.targets))
    if sample_type == 'mnist':
        sample = dataset.data[idx]
        truth = dataset.targets[idx]
    else:
        path, truth = dataset.samples[idx]
        sample = torch.from_numpy(np.array(Image.open(path)))

    sample = sample.unsqueeze(0)
    sample = transform(sample)
    return sample.to(device), truth
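

# Builds a 3x2 figure: row 1 shows the two sampled images, row 2 the per-class
# uncertainties from normalized softplus outputs, row 3 the same statistics
# from plain softmax outputs.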
def run(net_type, weight_path, notmnist_dir):
    init_dataset(notmnist_dir)

    layer_type = cfg.layer_type
    activation_type = cfg.activation_type

    net = getModel(net_type, 1, 10, priors=None, layer_type=layer_type, activation_type=activation_type)
    net.load_state_dict(torch.load(weight_path, map_location=device))
    # Kept in train mode, presumably so the Bayesian layers keep sampling
    # weights; the variation across the T forward passes is what the
    # uncertainty estimates measure.
    net.train()
    net.to(device)

    fig = plt.figure()
    fig.suptitle('Uncertainty Estimation', fontsize='x-large')
    mnist_img = fig.add_subplot(321)
    notmnist_img = fig.add_subplot(322)
    epi_stats_norm = fig.add_subplot(323)
    ale_stats_norm = fig.add_subplot(324)
    epi_stats_soft = fig.add_subplot(325)
    ale_stats_soft = fig.add_subplot(326)

    sample_mnist, truth_mnist = get_sample(mnist_set)
    pred_mnist, epi_mnist_norm, ale_mnist_norm = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=True)
    pred_mnist, epi_mnist_soft, ale_mnist_soft = get_uncertainty_per_image(net, sample_mnist, T=25, normalized=False)
    mnist_img.imshow(sample_mnist.squeeze().cpu(), cmap='gray')
    mnist_img.axis('off')
    mnist_img.set_title('MNIST Truth: {} Prediction: {}'.format(int(truth_mnist), int(np.argmax(pred_mnist))))

    sample_notmnist, truth_notmnist = get_sample(notmnist_set, sample_type='notmnist')
    pred_notmnist, epi_notmnist_norm, ale_notmnist_norm = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=True)
    pred_notmnist, epi_notmnist_soft, ale_notmnist_soft = get_uncertainty_per_image(net, sample_notmnist, T=25, normalized=False)
    notmnist_img.imshow(sample_notmnist.squeeze().cpu(), cmap='gray')
    notmnist_img.axis('off')
    notmnist_img.set_title('notMNIST Truth: {}({}) Prediction: {}({})'.format(
        int(truth_notmnist), chr(65 + truth_notmnist), int(np.argmax(pred_notmnist)), chr(65 + np.argmax(pred_notmnist))))

    x = list(range(10))
    # Named df rather than data so the imported `data` module is not shadowed.
    df = pd.DataFrame({
        'epistemic_norm': np.hstack([epi_mnist_norm, epi_notmnist_norm]),
        'aleatoric_norm': np.hstack([ale_mnist_norm, ale_notmnist_norm]),
        'epistemic_soft': np.hstack([epi_mnist_soft, epi_notmnist_soft]),
        'aleatoric_soft': np.hstack([ale_mnist_soft, ale_notmnist_soft]),
        'category': np.hstack([x, x]),
        'dataset': np.hstack([['MNIST'] * 10, ['notMNIST'] * 10]),
    })
    print(df)
    sns.barplot(x='category', y='epistemic_norm', hue='dataset', data=df, ax=epi_stats_norm)
    sns.barplot(x='category', y='aleatoric_norm', hue='dataset', data=df, ax=ale_stats_norm)
    epi_stats_norm.set_title('Epistemic Uncertainty (Normalized)')
    ale_stats_norm.set_title('Aleatoric Uncertainty (Normalized)')

    sns.barplot(x='category', y='epistemic_soft', hue='dataset', data=df, ax=epi_stats_soft)
    sns.barplot(x='category', y='aleatoric_soft', hue='dataset', data=df, ax=ale_stats_soft)
    epi_stats_soft.set_title('Epistemic Uncertainty (Softmax)')
    ale_stats_soft.set_title('Aleatoric Uncertainty (Softmax)')

    plt.show()
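

# Example invocation (script name assumed; paths are the argparse defaults below):
#   python uncertainty_estimation.py --net_type lenet \
#       --weights_path checkpoints/MNIST/bayesian/model_lenet.pt \
#       --notmnist_dir data/notMNIST_small/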
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="PyTorch Uncertainty Estimation b/w MNIST and notMNIST")
    parser.add_argument('--net_type', default='lenet', type=str, help='model')
    parser.add_argument('--weights_path', default='checkpoints/MNIST/bayesian/model_lenet.pt', type=str, help='weights for model')
    parser.add_argument('--notmnist_dir', default='data/notMNIST_small/', type=str, help='directory of the notMNIST dataset')
    args = parser.parse_args()

    run(args.net_type, args.weights_path, args.notmnist_dir)