New class to make Gaussian noise

This commit is contained in:
Eduardo Cueto-Mendoza 2024-09-20 09:26:56 +01:00
parent 146521a811
commit 941cb7b00d
5 changed files with 58 additions and 1 deletions

View File

@ -1,4 +1,5 @@
import torch import torch
from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from torchvision import datasets from torchvision import datasets
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -6,11 +7,24 @@ from torch.utils.data import DataLoader
import torchvision.transforms as transforms import torchvision.transforms as transforms
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0): def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
transform_train = transforms.Compose([ transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=4), transforms.RandomCrop(28, padding=4),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)), transforms.Normalize((0.5,), (0.5,)),
AddGaussianNoise(0., 0.99),
]) ])
transform_test = transforms.Compose([ transform_test = transforms.Compose([
@ -36,5 +50,48 @@ def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
return train_loader, test_loader, train_eval_loader return train_loader, test_loader, train_eval_loader
def get_cifar_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
transform_train = transforms.Compose([
#transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
AddGaussianNoise(0., 0.99),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_train), batch_size=batch_size,
shuffle=True, num_workers=2, drop_last=True
)
train_eval_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
)
test_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
return train_loader, test_loader, train_eval_loader
if __name__ == '__main__': if __name__ == '__main__':
test = get_mnist_loaders() train_loader, test_loader, train_eval_loader\
= get_cifar_loaders()
#for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
# sample_idx = torch.randint(len(data), size=(1,)).item()
# img = data[sample_idx]
# # print(data)
images, labels = next(iter(train_loader))
#plt.imshow(images[0].reshape(3,32,32).transpose(0,2,3,1))
plt.imshow(images[0])
plt.show()
#print(images[0].shape)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.