2020-05-02 21:40:17 +00:00
|
|
|
{
|
|
|
|
"cells": [
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 1,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"import torch\n",
|
|
|
|
"import torch.nn as nn\n",
|
|
|
|
"import torchvision.transforms as transforms\n",
|
|
|
|
"import torchvision.datasets as datasets"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 2,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"mean_gray = 0.1307\n",
"stddev_gray = 0.3081\n",
"\n",
"# Normalize applies, per channel: input[channel] = (input[channel] - mean[channel]) / std[channel]\n",
"# The pipeline gets its own name so the imported `transforms` module is not\n",
"# overwritten; rebinding it made this cell raise AttributeError when run a second time.\n",
"mnist_transform = transforms.Compose([transforms.ToTensor(),\n",
"                                      transforms.Normalize((mean_gray,), (stddev_gray,))])\n",
"\n",
"train_dataset = datasets.MNIST(root='./data',\n",
"                               train=True,\n",
"                               transform=mnist_transform,\n",
"                               download=True)\n",
"\n",
"test_dataset = datasets.MNIST(root='./data',\n",
"                              train=False,\n",
"                              transform=mnist_transform)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 3,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
2020-05-03 18:26:33 +00:00
|
|
|
"<matplotlib.image.AxesImage at 0x7f66cc282e90>"
|
2020-05-02 21:40:17 +00:00
|
|
|
]
|
|
|
|
},
|
|
|
|
"execution_count": 3,
|
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANr0lEQVR4nO3db6xUdX7H8c9HujwRJFBSvLK0LqvGbGrKNjdYLTE2usTyBPeBm0VtaFy9mKzJqg0tUiMas2raWh+ZNazKotmy2UR2NdBk15JVW2OIV0MFvd31lqALuUIURFcfbJFvH9yDueA9Zy4zZ+YM9/t+JTczc74z53wz4cP5O+fniBCA6e+sphsA0BuEHUiCsANJEHYgCcIOJPEHvVyYbQ79A10WEZ5sekdrdtvX2P617VHb6zqZF4Ducrvn2W3PkPQbSd+QtF/Sq5JWRcRbFZ9hzQ50WTfW7EsljUbE3oj4vaSfSFrZwfwAdFEnYV8o6bcTXu8vpp3E9pDtYdvDHSwLQIe6foAuIjZK2iixGQ80qZM1+wFJiya8/nIxDUAf6iTsr0q60PZXbM+U9G1Jz9XTFoC6tb0ZHxHHbN8m6ReSZkh6MiLerK0zALVq+9RbWwtjnx3ouq5cVAPgzEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI9HbIZmOiiiy6qrD/22GOV9RtuuKGyPjY2dto9TWes2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiWlznn327NmV9VmzZlXWjx49Wln/9NNPT7snVFuxYkVl/Yorrqis33zzzZX1Bx98sLR27Nixys9ORx2F3fY+SR9L+kzSsYgYrKMpAPWrY83+VxHxfg3zAdBF7LMDSXQa9pD0S9uv2R6a7A22h2wP2x7ucFkAOtDpZvyyiDhg+48kPW/7fyLipYlviIiNkjZKku3ocHkA2tTRmj0iDhSPhyT9TNLSOpoCUL+2w277bNuzTzyXtFzSnroaA1AvR7S3ZW17scbX5tL47sC/RcT3W3yma5vx999/f2X9rrvuqqyvXbu2sv7II4+cdk+otmzZssr6Cy+80NH8L7744tLa6OhoR/PuZxHhyaa3vc8eEXsl/VnbHQHoKU69AUkQdiAJwg4kQdiBJAg7kMS0+YlrpzZs2FBZ37t3b2nt2WefrbudFM4999ymW0iFNTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59kKrW01v2rSptLZ8+fLKzw4P570jV9X3euedd3Z12dddd11preo209MVa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSGLanGfft29fV+d/zjnnlNbuu+++ys/eeOONlfUjR4601dOZ4IILLiitLV3KmCK9xJodSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Joe8jmthbWxSGbZ8yYUVlfv359Zb3VfeM7ceutt1bWH3/88a4tu2nnnXdeaa3VkMyLFy/uaNkM2Xyylmt220/aPmR7z4Rp82w/b/vt4nFunc0CqN9UNuN/JOmaU6atk7QjIi6UtKN4DaCPtQx7RLwk6fApk1dK2lw83yzp2pr7AlCzdq+NXxARY8Xz9yQtKHuj7SFJQ20uB0BNOv4hTERE1YG3iNgoaaPU3QN0AKq1e+rtoO0BSSoeD9XXEoBuaDfsz0laXTxfLYkxi4E+1/I8u+0tkq6UNF/SQUkbJP1c0k8l/bGkdyR9KyJOPYg32bwa24yfM2dOZX3nzp2V9arfZbeye/fuyvrVV19dWf/ggw/aXnbTlixZUlrr9v30Oc9+spb77BGxqqR0VUcdAegpLpcFkiDsQBKEHUiCsANJEHYgiWlzK+lWjh49Wll/+eWXK+udnHq75JJLKuuLFi2qrHfz1NvMmTMr62vWrOlo/lXDJq
O3WLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJpzrO38sorr1TWV69eXVnvxGWXXVZZ37VrV2X98ssvb6smSbNmzaqs33333ZX1Jo2MjFTWp/NQ2O1gzQ4kQdiBJAg7kARhB5Ig7EAShB1IgrADSUybIZu77emnny6tXX/99T3spF5nnVX9//3x48d71En9hobKRx174oknethJb7U9ZDOA6YGwA0kQdiAJwg4kQdiBJAg7kARhB5LgPPsUNTn0cDfZk56S/Vwv/33UbdOmTaW1W265pYed9Fbb59ltP2n7kO09E6bda/uA7V3F34o6mwVQv6lsxv9I0jWTTH8kIpYUf/9eb1sA6tYy7BHxkqTDPegFQBd1coDuNttvFJv5c8veZHvI9rDtM3fHFpgG2g37DyR9VdISSWOSHi57Y0RsjIjBiBhsc1kAatBW2CPiYER8FhHHJf1Q0tJ62wJQt7bCbntgwstvStpT9l4A/aHlfeNtb5F0paT5tvdL2iDpSttLJIWkfZI6G8QbjRkdHa2stzrPvn379sr60aNHS2v33HNP5WdRr5Zhj4hVk0yevr/8B6YpLpcFkiDsQBKEHUiCsANJEHYgCYZsPgMcPlz904R33323tPbww6UXN0qStmzZ0lZPU1X102BOvfUWa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSILz7FO0d+/e0tpTTz1V+dnFixdX1kdGRirrjz76aGV9zx5uJzCZ5cuXl9bmzi29k5ok6ciRI3W30zjW7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZp+ijjz4qrd1000097ARTtXDhwtLazJkze9hJf2DNDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ4dXfXhhx+W1sbGxio/OzAwUHc7n3vggQcq62vWVI9CfuzYsTrb6YmWa3bbi2z/yvZbtt+0/b1i+jzbz9t+u3isvhsAgEZNZTP+mKS/i4ivSfoLSd+1/TVJ6yTtiIgLJe0oXgPoUy3DHhFjEfF68fxjSSOSFkpaKWlz8bbNkq7tVpMAOnda++y2z5f0dUk7JS2IiBM7Xe9JWlDymSFJQ+23CKAOUz4ab3uWpGck3R4RJ/0qJCJCUkz2uYjYGBGDETHYUacAOjKlsNv+ksaD/uOI2FpMPmh7oKgPSDrUnRYB1MHjK+WKN9jW+D754Yi4fcL0f5b0QUQ8ZHudpHkR8fct5lW9MKRy6aWXVta3bt1aWV+wYNI9x1rMmTOnsv7JJ590bdmdighPNn0q++x/KelvJO22vauYtl7SQ5J+avs7kt6R9K06GgXQHS3DHhH/JWnS/ykkXVVvOwC6hctlgSQIO5AEYQeSIOxAEoQdSKLlefZaF8Z5dpyGwcHqiy63bdtWWZ8/f37by77qquoTTS+++GLb8+62svPsrNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAluJY2+NTw8XFm/4447Kutr164trW3fvr2jZZ+JWLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL8nh2YZvg9O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4k0TLsthfZ/pXtt2y/aft7xfR7bR+wvav4W9H9dgG0q+VFNbYHJA1ExOu2Z0t6TdK1Gh+P/XcR8S9TXhgX1QBdV3ZRzVTGZx+TNFY8/9j2iKSF9bYHoNtOa5/d9vmSvi5pZzHpNttv2H7S9tySzwzZHrY9/e7zA5xBpnxtvO1Zkl6U9P2I2Gp7gaT3JYWk+zW+qX9Ti3mwGQ90Wdlm/JTCbvtLkrZJ+kVE/Osk9fMlbYuIP20xH8IOdFnbP4SxbUlPSBqZGPTiwN0J35S0p9MmAXTPVI7GL5P0n5J2SzpeTF4vaZWkJRrfjN8naU1xMK9qXqzZgS7raDO+LoQd6D5+zw4kR9iBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii5Q0na/a+pHcmvJ5fTO
tH/dpbv/Yl0Vu76uztT8oKPf09+xcWbg9HxGBjDVTo1976tS+J3trVq97YjAeSIOxAEk2HfWPDy6/Sr731a18SvbW
|
|
|
|
"text/plain": [
|
|
|
|
"<Figure size 432x288 with 1 Axes>"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
"metadata": {
|
|
|
|
"needs_background": "light"
|
|
|
|
},
|
|
|
|
"output_type": "display_data"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"import matplotlib.pyplot as plt\n",
"\n",
"# Undo the Normalize step (x * std + mean) so training sample 20 is displayed\n",
"# on its pre-normalisation scale.\n",
"random_img = mean_gray + stddev_gray * train_dataset[20][0].numpy()\n",
"plt.imshow(random_img.reshape((28, 28)), cmap='gray')"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 4,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
2020-05-03 18:26:33 +00:00
|
|
|
"name": "stdout",
|
|
|
|
"output_type": "stream",
|
|
|
|
"text": [
|
|
|
|
"4\n"
|
|
|
|
]
|
2020-05-02 21:40:17 +00:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# Second element of training sample 20, i.e. the target stored with the image plotted above\n",
"print(train_dataset[20][1])"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": 5,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
|
|
|
"source": [
|
|
|
|
"batch_size = 100\n",
"\n",
"# Both loaders share the same batch size and reshuffle their dataset on every pass.\n",
"train_load = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_load = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)\n"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
2020-05-03 18:26:33 +00:00
|
|
|
"execution_count": 6,
|
2020-05-02 21:40:17 +00:00
|
|
|
"metadata": {},
|
|
|
|
"outputs": [
|
|
|
|
{
|
|
|
|
"data": {
|
|
|
|
"text/plain": [
|
|
|
|
"100"
|
|
|
|
]
|
|
|
|
},
|
2020-05-03 18:26:33 +00:00
|
|
|
"execution_count": 6,
|
2020-05-02 21:40:17 +00:00
|
|
|
"metadata": {},
|
|
|
|
"output_type": "execute_result"
|
|
|
|
}
|
|
|
|
],
|
|
|
|
"source": [
|
|
|
|
"# Number of batches produced by one pass over the test loader\n",
"len(test_load)"
|
|
|
|
]
|
|
|
|
},
|
|
|
|
{
|
|
|
|
"cell_type": "code",
|
|
|
|
"execution_count": null,
|
|
|
|
"metadata": {},
|
|
|
|
"outputs": [],
|
2020-05-03 18:26:33 +00:00
|
|
|
"source": [
|
|
|
|
"class CNN(nn.Module):\n",
"    \"\"\"Small convolutional classifier: two conv stages followed by two linear layers.\n",
"\n",
"    Expects input of shape (N, 1, 28, 28) and returns raw scores of shape (N, 10).\n",
"    Each conv stage is Conv2d -> BatchNorm2d -> ReLU -> 2x2 MaxPool2d.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        \"\"\"Create the layers; relu and maxpool hold no parameters and are reused by both stages.\"\"\"\n",
"        super().__init__()\n",
"        # 'Same' padding keeps output_size == input_size: padding = (filter_size - 1) / 2\n",
"        # Conv output size in general: [(input_size - filter_size + 2*padding) / stride] + 1\n",
"        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3,3), stride=1, padding=1)\n",
"        self.batchnorm1 = nn.BatchNorm2d(8)\n",
"        self.relu = nn.ReLU()\n",
"        # Each pooling step halves the spatial size: 28 -> 14 -> 7\n",
"        self.maxpool = nn.MaxPool2d(kernel_size=(2,2))\n",
"        # Same padding for the 5x5 filter: (5 - 1) / 2 = 2\n",
"        self.cnn2 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=(5,5), stride=1, padding=2)\n",
"        self.batchnorm2 = nn.BatchNorm2d(32)\n",
"        # Flattened feature count after the second pooling: 32 * 7 * 7 = 1568\n",
"        self.fc1 = nn.Linear(1568, 600)\n",
"        self.dropout = nn.Dropout(p=0.5)\n",
"        self.fc2 = nn.Linear(600, 10)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Return scores of shape (N, 10) for a batch x of shape (N, 1, 28, 28).\"\"\"\n",
"        # Stage 1: (N, 1, 28, 28) -> (N, 8, 14, 14)\n",
"        out = self.cnn1(x)\n",
"        out = self.batchnorm1(out)\n",
"        out = self.relu(out)\n",
"        out = self.maxpool(out)\n",
"        # Stage 2: (N, 8, 14, 14) -> (N, 32, 7, 7)\n",
"        out = self.cnn2(out)\n",
"        out = self.batchnorm2(out)\n",
"        out = self.relu(out)\n",
"        out = self.maxpool(out)\n",
"        # Flatten the 32 feature maps per sample using the tensor's own batch\n",
"        # dimension; a fixed global batch size breaks on partial batches and\n",
"        # on single-image inference.\n",
"        out = out.view(out.size(0), -1)\n",
"        out = self.fc1(out)\n",
"        out = self.relu(out)\n",
"        out = self.dropout(out)\n",
"        out = self.fc2(out)\n",
"\n",
"        return out"
|
|
|
|
]
|
2020-05-02 21:40:17 +00:00
|
|
|
}
|
|
|
|
],
|
|
|
|
"metadata": {
|
|
|
|
"kernelspec": {
|
|
|
|
"display_name": "Python 3",
|
|
|
|
"language": "python",
|
|
|
|
"name": "python3"
|
|
|
|
},
|
|
|
|
"language_info": {
|
|
|
|
"codemirror_mode": {
|
|
|
|
"name": "ipython",
|
|
|
|
"version": 3
|
|
|
|
},
|
|
|
|
"file_extension": ".py",
|
|
|
|
"mimetype": "text/x-python",
|
|
|
|
"name": "python",
|
|
|
|
"nbconvert_exporter": "python",
|
|
|
|
"pygments_lexer": "ipython3",
|
|
|
|
"version": "3.7.6"
|
|
|
|
}
|
|
|
|
},
|
|
|
|
"nbformat": 4,
|
|
|
|
"nbformat_minor": 4
|
|
|
|
}
|