initial commit
This commit is contained in:
commit
146521a811
|
@ -0,0 +1,40 @@
|
||||||
|
import torch
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torchvision import datasets
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
|
||||||
|
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||||
|
transform_train = transforms.Compose([
|
||||||
|
transforms.RandomCrop(28, padding=4),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
])
|
||||||
|
|
||||||
|
transform_test = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
|
])
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
|
||||||
|
shuffle=True, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
train_eval_loader = DataLoader(
|
||||||
|
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
|
||||||
|
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
test_loader = DataLoader(
|
||||||
|
datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
|
||||||
|
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_loader, test_loader, train_eval_loader
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test = get_mnist_loaders()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue