New class to make Gaussian noise
This commit is contained in:
parent
146521a811
commit
941cb7b00d
|
@ -1,4 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from torchvision import datasets
|
from torchvision import datasets
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
@ -6,11 +7,24 @@ from torch.utils.data import DataLoader
|
||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
|
||||||
|
class AddGaussianNoise(object):
|
||||||
|
def __init__(self, mean=0., std=1.):
|
||||||
|
self.std = std
|
||||||
|
self.mean = mean
|
||||||
|
|
||||||
|
def __call__(self, tensor):
|
||||||
|
return tensor + torch.randn(tensor.size()) * self.std + self.mean
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
|
||||||
|
|
||||||
|
|
||||||
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||||
transform_train = transforms.Compose([
|
transform_train = transforms.Compose([
|
||||||
transforms.RandomCrop(28, padding=4),
|
transforms.RandomCrop(28, padding=4),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5,), (0.5,)),
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
AddGaussianNoise(0., 0.99),
|
||||||
])
|
])
|
||||||
|
|
||||||
transform_test = transforms.Compose([
|
transform_test = transforms.Compose([
|
||||||
|
@ -36,5 +50,48 @@ def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||||
return train_loader, test_loader, train_eval_loader
|
return train_loader, test_loader, train_eval_loader
|
||||||
|
|
||||||
|
|
||||||
|
def get_cifar_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||||
|
transform_train = transforms.Compose([
|
||||||
|
#transforms.RandomCrop(32, padding=4),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
AddGaussianNoise(0., 0.99),
|
||||||
|
])
|
||||||
|
|
||||||
|
transform_test = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
])
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_train), batch_size=batch_size,
|
||||||
|
shuffle=True, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
train_eval_loader = DataLoader(
|
||||||
|
datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_test),
|
||||||
|
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loader = DataLoader(
|
||||||
|
datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=transform_test),
|
||||||
|
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, test_loader, train_eval_loader
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test = get_mnist_loaders()
|
train_loader, test_loader, train_eval_loader\
|
||||||
|
= get_cifar_loaders()
|
||||||
|
|
||||||
|
#for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
|
||||||
|
# sample_idx = torch.randint(len(data), size=(1,)).item()
|
||||||
|
# img = data[sample_idx]
|
||||||
|
# # print(data)
|
||||||
|
|
||||||
|
images, labels = next(iter(train_loader))
|
||||||
|
#plt.imshow(images[0].reshape(3,32,32).transpose(0,2,3,1))
|
||||||
|
plt.imshow(images[0])
|
||||||
|
plt.show()
|
||||||
|
#print(images[0].shape)
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue