diff --git a/making_noise.py b/making_noise.py index f588680..9f7dff6 100644 --- a/making_noise.py +++ b/making_noise.py @@ -1,4 +1,5 @@ import torch +from tqdm import tqdm import matplotlib.pyplot as plt from torchvision import datasets from torch.utils.data import Dataset @@ -6,11 +7,24 @@ from torch.utils.data import DataLoader import torchvision.transforms as transforms +class AddGaussianNoise(object): + def __init__(self, mean=0., std=1.): + self.std = std + self.mean = mean + + def __call__(self, tensor): + return tensor + torch.randn(tensor.size()) * self.std + self.mean + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + + def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0): transform_train = transforms.Compose([ transforms.RandomCrop(28, padding=4), transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)), + AddGaussianNoise(0., 0.99), ]) transform_test = transforms.Compose([ @@ -36,5 +50,48 @@ def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0): return train_loader, test_loader, train_eval_loader +def get_cifar_loaders(batch_size=128, test_batch_size=1000, perc=1.0): + transform_train = transforms.Compose([ + #transforms.RandomCrop(32, padding=4), + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + AddGaussianNoise(0., 0.99), + ]) + + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5,), (0.5,)), + ]) + + train_loader = DataLoader( + datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_train), batch_size=batch_size, + shuffle=True, num_workers=2, drop_last=True + ) + + train_eval_loader = DataLoader( + datasets.CIFAR10(root='.data/cifar', train=True, download=True, transform=transform_test), + batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True + ) + + test_loader = DataLoader( + datasets.CIFAR10(root='.data/cifar', train=False, download=True, transform=transform_test), + batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True + ) + + return train_loader, test_loader, train_eval_loader + + if __name__ == '__main__': - test = get_mnist_loaders() + train_loader, test_loader, train_eval_loader\ + = get_cifar_loaders() + + #for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)): + # sample_idx = torch.randint(len(data), size=(1,)).item() + # img = data[sample_idx] + # # print(data) + + images, labels = next(iter(train_loader)) + #plt.imshow(images[0].reshape(3,32,32).transpose(0,2,3,1)) + plt.imshow(images[0]) + plt.show() + #print(images[0].shape) diff --git a/t10k-images-idx3-ubyte.gz b/t10k-images-idx3-ubyte.gz deleted file mode 100644 index 84332eb..0000000 Binary files a/t10k-images-idx3-ubyte.gz and /dev/null differ diff --git a/t10k-labels-idx1-ubyte.gz b/t10k-labels-idx1-ubyte.gz deleted file mode 100644 index 0538792..0000000 Binary files a/t10k-labels-idx1-ubyte.gz and /dev/null differ diff --git a/train-images-idx3-ubyte.gz b/train-images-idx3-ubyte.gz deleted file mode 100644 index 692cb8c..0000000 Binary files a/train-images-idx3-ubyte.gz and /dev/null differ diff --git a/train-labels-idx1-ubyte.gz b/train-labels-idx1-ubyte.gz deleted file mode 100644 index 7abdabb..0000000 Binary files a/train-labels-idx1-ubyte.gz and /dev/null differ