bayesiancnn/data/data.py

import random
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data.sampler import SubsetRandomSampler
class AddNoNoise(object):
    """Identity transform: keeps the (mean, std) interface but adds no noise."""
    def __init__(self, mean=0.0, std=1.0):
        self.std = std
        self.mean = mean
    def __call__(self, tensor):
        return tensor
    def __repr__(self):
        return self.__class__.__name__ + "(no noise)"
class AddGaussianNoise(object):
def __init__(self, mean=0.0, std=1.0):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
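# Illustrative sketch (not part of the original pipeline): one way to sanity-check
# that AddGaussianNoise shifts a tensor's statistics as expected. The tensor shape
# and tolerance below are arbitrary assumptions.
#
#   x = torch.zeros(1, 32, 32)
#   noisy = AddGaussianNoise(mean=0.5, std=0.1)(x)
#   assert abs(noisy.mean().item() - 0.5) < 0.05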
class AddRaleighNoise(object):
    # Approximates Rayleigh noise (the "Raleigh" spelling is kept for API
    # compatibility) with a Gaussian draw matching the Rayleigh mean and std
    # implied by parameters a and b.
    def __init__(self, a=0.0, b=0.0):
        self.std = np.sqrt(b * (4 - np.pi) / 4)  # Rayleigh variance is b(4 - pi)/4
        self.mean = a + np.sqrt((np.pi * b) / 4)
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
class AddErlangNoise(object):
    # Approximates Erlang (gamma) noise with rate a and shape b using a
    # Gaussian draw with the matching mean (b / a) and std (sqrt(b) / a).
    def __init__(self, a=0.0, b=0.0):
        if a == 0.0:
            self.std = 0.0
            self.mean = 0.0
        else:
            self.std = np.sqrt(b) / a
            self.mean = b / a
    def __call__(self, tensor):
        if self.mean == 0.0:
            # No noise requested: return the input unchanged rather than
            # multiplying by zero, which would blank the image.
            return tensor
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
class AddExponentialNoise(object):
    # Approximates exponential noise with rate a using a Gaussian draw with
    # the matching mean and std (both 1 / a for an exponential distribution).
    def __init__(self, a=0.0, b=0):
        if a == 0.0:
            self.std = 0.0
            self.mean = 0.0
        else:
            self.std = 1 / a
            self.mean = 1 / a
    def __call__(self, tensor):
        if self.mean == 0.0:
            # No noise requested: return the input unchanged.
            return tensor
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
class AddUniformNoise(object):
    # Approximates uniform noise on [a, b] using a Gaussian draw with the
    # matching mean ((a + b) / 2) and std ((b - a) / sqrt(12)).
    def __init__(self, a=0.0, b=0.0):
        if a == 0.0:
            self.std = 0.0
            self.mean = 0.0
        else:
            self.std = (b - a) / np.sqrt(12)  # variance of U(a, b) is (b - a)^2 / 12
            self.mean = (b + a) / 2
    def __call__(self, tensor):
        if self.mean == 0.0:
            # No noise requested: return the input unchanged.
            return tensor
        return tensor + (torch.randn(tensor.size()) * self.std + self.mean)
def __repr__(self):
return self.__class__.__name__ + "(mean={0}, std={1})".format(
self.mean, self.std
)
class AddImpulseNoise(object):
    # Scales the whole image by +a or -a with equal probability, using a
    # single Gaussian draw to pick the sign (the original drew twice, so the
    # two comparisons could disagree and silently zero the image).
    def __init__(self, a=0.0, b=0):
        self.value = a
    def __call__(self, tensor):
        sign = random.gauss(0, 1)
        if sign > 0:
            return tensor * self.value
        elif sign < 0:
            return tensor * (-1 * self.value)
        else:
            return tensor * 0.0
    def __repr__(self):
        return self.__class__.__name__ + "(a={0})".format(self.value)
class CustomDataset(Dataset):
    """Wraps pre-extracted (data, labels) tensors so a transform can be applied per sample."""
    def __init__(self, data, labels, transform=None):
self.data = data
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
if self.transform:
sample = self.transform(sample)
return sample, label
def extract_classes(dataset, classes):
    """Return (data, targets) restricted to the given target classes.
    Assumes dataset.targets is a tensor, as it is for torchvision MNIST."""
idx = torch.zeros_like(dataset.targets, dtype=torch.bool)
for target in classes:
idx = idx | (dataset.targets == target)
data, targets = dataset.data[idx], dataset.targets[idx]
return data, targets
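# Illustrative sketch (assumed usage, not in the original code): extract_classes
# plus CustomDataset can build a two-class MNIST subset by hand, e.g.
#
#   raw = torchvision.datasets.MNIST(root="./data", train=True, download=True)
#   data, targets = extract_classes(raw, [0, 1])
#   subset = CustomDataset(data, targets, transform=transforms.ToPILImage())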
def getDataset(dataset, noise=None, mean=0.0, std=0.0):
    """Return (trainset, testset, inputs, num_classes) for the requested dataset.

    `noise` selects one of the noise transforms above; `mean` and `std` are
    forwarded to its constructor (the non-Gaussian transforms interpret them
    as their (a, b) parameters).
    """
    if noise is None:
        noise_type = AddNoNoise
    elif noise == "gaussian":
        noise_type = AddGaussianNoise
    elif noise == "raleigh":
        noise_type = AddRaleighNoise
    elif noise == "erlang":
        noise_type = AddErlangNoise
    elif noise == "exponential":
        noise_type = AddExponentialNoise
    elif noise == "uniform":
        noise_type = AddUniformNoise
    elif noise == "impulse":
        noise_type = AddImpulseNoise
    else:
        raise ValueError(f"Unknown noise type: {noise}")
    print(f"{noise_type.__name__} noise added")
transform_split_mnist = transforms.Compose(
[
transforms.ToPILImage(),
transforms.Resize((32, 32)),
transforms.ToTensor(),
noise_type(mean, std),
]
)
transform_mnist = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.ToTensor(),
noise_type(mean, std),
]
)
    # Note: this augmenting transform (including RandomHorizontalFlip) is reused
    # for the CIFAR test split as well.
    transform_cifar = transforms.Compose(
[
transforms.Resize((32, 32)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
noise_type(mean, std),
]
)
if dataset == "CIFAR10":
trainset = torchvision.datasets.CIFAR10(
root="./data", train=True, download=True, transform=transform_cifar
)
testset = torchvision.datasets.CIFAR10(
root="./data", train=False, download=True, transform=transform_cifar
)
num_classes = 10
inputs = 3
elif dataset == "CIFAR100":
trainset = torchvision.datasets.CIFAR100(
root="./data", train=True, download=True, transform=transform_cifar
)
testset = torchvision.datasets.CIFAR100(
root="./data", train=False, download=True, transform=transform_cifar
)
num_classes = 100
inputs = 3
elif dataset == "MNIST":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
num_classes = 10
inputs = 1
elif dataset == "SplitMNIST-2.1":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [0, 1, 2, 3, 4])
test_data, test_targets = extract_classes(testset, [0, 1, 2, 3, 4])
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 5
inputs = 1
elif dataset == "SplitMNIST-2.2":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [5, 6, 7, 8, 9])
test_data, test_targets = extract_classes(testset, [5, 6, 7, 8, 9])
train_targets -= 5 # Mapping target 5-9 to 0-4
test_targets -= 5 # Hence, add 5 after prediction
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 5
inputs = 1
elif dataset == "SplitMNIST-5.1":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [0, 1])
test_data, test_targets = extract_classes(testset, [0, 1])
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 2
inputs = 1
elif dataset == "SplitMNIST-5.2":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [2, 3])
test_data, test_targets = extract_classes(testset, [2, 3])
train_targets -= 2 # Mapping target 2-3 to 0-1
test_targets -= 2 # Hence, add 2 after prediction
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 2
inputs = 1
elif dataset == "SplitMNIST-5.3":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [4, 5])
test_data, test_targets = extract_classes(testset, [4, 5])
train_targets -= 4 # Mapping target 4-5 to 0-1
test_targets -= 4 # Hence, add 4 after prediction
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 2
inputs = 1
elif dataset == "SplitMNIST-5.4":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [6, 7])
test_data, test_targets = extract_classes(testset, [6, 7])
train_targets -= 6 # Mapping target 6-7 to 0-1
test_targets -= 6 # Hence, add 6 after prediction
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 2
inputs = 1
elif dataset == "SplitMNIST-5.5":
trainset = torchvision.datasets.MNIST(
root="./data", train=True, download=True, transform=transform_mnist
)
testset = torchvision.datasets.MNIST(
root="./data", train=False, download=True, transform=transform_mnist
)
train_data, train_targets = extract_classes(trainset, [8, 9])
test_data, test_targets = extract_classes(testset, [8, 9])
train_targets -= 8 # Mapping target 8-9 to 0-1
test_targets -= 8 # Hence, add 8 after prediction
trainset = CustomDataset(
train_data, train_targets, transform=transform_split_mnist
)
testset = CustomDataset(
test_data, test_targets, transform=transform_split_mnist
)
num_classes = 2
inputs = 1
    else:
        raise ValueError(f"Unknown dataset: {dataset}")
    return trainset, testset, inputs, num_classes
def getDataloader(trainset, testset, valid_size, batch_size, num_workers):
    """Split trainset into train/validation loaders (valid_size is the validation
    fraction) and return (train_loader, valid_loader, test_loader)."""
num_train = len(trainset)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers
)
valid_loader = torch.utils.data.DataLoader(
trainset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers
)
test_loader = torch.utils.data.DataLoader(
testset, batch_size=batch_size, num_workers=num_workers
)
return train_loader, valid_loader, test_loader
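# Minimal usage sketch, assuming this module is run directly; the dataset name,
# noise settings, batch size, and split fraction below are illustrative values only.
if __name__ == "__main__":
    trainset, testset, inputs, num_classes = getDataset(
        "MNIST", noise="gaussian", mean=0.0, std=0.1
    )
    train_loader, valid_loader, test_loader = getDataloader(
        trainset, testset, valid_size=0.2, batch_size=64, num_workers=0
    )
    print(
        f"{inputs} input channel(s), {num_classes} classes, "
        f"{len(train_loader)} train batches, {len(valid_loader)} valid batches"
    )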