New class to make Gaussian noise
This commit is contained in:
parent
146521a811
commit
941cb7b00d
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import datasets
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -6,11 +7,24 @@ from torch.utils.data import DataLoader
|
|||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
class AddGaussianNoise(object):
|
||||
def __init__(self, mean=0., std=1.):
|
||||
self.std = std
|
||||
self.mean = mean
|
||||
|
||||
def __call__(self, tensor):
|
||||
return tensor + torch.randn(tensor.size()) * self.std + self.mean
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
|
||||
|
||||
|
||||
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(28, padding=4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
AddGaussianNoise(0., 0.99),
|
||||
])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
|
@ -36,5 +50,48 @@ def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
|||
return train_loader, test_loader, train_eval_loader
|
||||
|
||||
|
||||
def get_cifar_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||
transform_train = transforms.Compose([
|
||||
#transforms.RandomCrop(32, padding=4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
AddGaussianNoise(0., 0.99),
|
||||
])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
])
|
||||
|
||||
train_loader = DataLoader(
|
||||
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_train), batch_size=batch_size,
|
||||
shuffle=True, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
train_eval_loader = DataLoader(
|
||||
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_test),
|
||||
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
test_loader = DataLoader(
|
||||
datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=transform_test),
|
||||
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
return train_loader, test_loader, train_eval_loader
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test = get_mnist_loaders()
|
||||
train_loader, test_loader, train_eval_loader\
|
||||
= get_cifar_loaders()
|
||||
|
||||
#for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||
# sample_idx = torch.randint(len(data), size=(1,)).item()
|
||||
# img = data[sample_idx]
|
||||
# # print(data)
|
||||
|
||||
images, labels = next(iter(train_loader))
|
||||
#plt.imshow(images[0].reshape(3,32,32).transpose(0,2,3,1))
|
||||
plt.imshow(images[0])
|
||||
plt.show()
|
||||
#print(images[0].shape)
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue