initial commit
This commit is contained in:
commit
146521a811
|
@ -0,0 +1,40 @@
|
|||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import datasets
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
def get_mnist_loaders(batch_size=128, test_batch_size=1000, perc=1.0):
|
||||
transform_train = transforms.Compose([
|
||||
transforms.RandomCrop(28, padding=4),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
])
|
||||
|
||||
transform_test = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
])
|
||||
|
||||
train_loader = DataLoader(
|
||||
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_train), batch_size=batch_size,
|
||||
shuffle=True, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
train_eval_loader = DataLoader(
|
||||
datasets.MNIST(root='.data/mnist', train=True, download=True, transform=transform_test),
|
||||
batch_size=test_batch_size, shuffle=True, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
test_loader = DataLoader(
|
||||
datasets.MNIST(root='.data/mnist', train=False, download=True, transform=transform_test),
|
||||
batch_size=test_batch_size, shuffle=False, num_workers=2, drop_last=True
|
||||
)
|
||||
|
||||
return train_loader, test_loader, train_eval_loader
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test = get_mnist_loaders()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue