New class to make Gaussian noise

This commit is contained in:
Eduardo Cueto-Mendoza 2024-09-20 09:26:56 +01:00
parent 146521a811
commit 941cb7b00d
5 changed files with 58 additions and 1 deletions

View File

@ -1,4 +1,5 @@
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import Dataset
@ -6,11 +7,24 @@ from torch.utils.data import DataLoader
import torchvision.transforms as transforms
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.std = std
self.mean = mean
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
transform_train = transforms.Compose([
transforms.RandomCrop(28, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
AddGaussianNoise(0., 0.99),
])
transform_test = transforms.Compose([
@ -36,5 +50,48 @@ def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
return train_loader, test_loader, train_eval_loader
def get_cifar_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
transform_train = transforms.Compose([
#transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
AddGaussianNoise(0., 0.99),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
train_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_train), batch_size=batch_size,
shuffle=True, num_workers=2, drop_last=True
)
train_eval_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
)
test_loader = DataLoader(
datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=transform_test),
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
)
return train_loader, test_loader, train_eval_loader
if __name__ == '__main__':
test = get_mnist_loaders()
train_loader, test_loader, train_eval_loader\
= get_cifar_loaders()
#for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
# sample_idx = torch.randint(len(data), size=(1,)).item()
# img = data[sample_idx]
# # print(data)
images, labels = next(iter(train_loader))
#plt.imshow(images[0].reshape(3,32,32).transpose(0,2,3,1))
plt.imshow(images[0])
plt.show()
#print(images[0].shape)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.