{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel statistics passed to Normalize below\n",
    "# (presumably the MNIST pixel mean / std -- TODO confirm).\n",
    "mean_gray = 0.1307\n",
    "stddev_gray = 0.3081\n",
    "\n",
    "# Normalize applies: input[channel] = (input[channel] - mean[channel]) / std[channel]\n",
    "# The pipeline gets its own name so the imported `transforms` module is not\n",
    "# overwritten; rebinding `transforms` would make this cell fail when re-run.\n",
    "mnist_transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((mean_gray,), (stddev_gray,)),\n",
    "])\n",
    "\n",
    "train_dataset = datasets.MNIST(root='./data',\n",
    "                               train=True,\n",
    "                               transform=mnist_transform,\n",
    "                               download=True)\n",
    "\n",
    "test_dataset = datasets.MNIST(root='./data',\n",
    "                              train=False,\n",
    "                              transform=mnist_transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANr0lEQVR4nO3db6xUdX7H8c9HujwRJFBSvLK0LqvGbGrKNjdYLTE2usTyBPeBm0VtaFy9mKzJqg0tUiMas2raWh+ZNazKotmy2UR2NdBk15JVW2OIV0MFvd31lqALuUIURFcfbJFvH9yDueA9Zy4zZ+YM9/t+JTczc74z53wz4cP5O+fniBCA6e+sphsA0BuEHUiCsANJEHYgCcIOJPEHvVyYbQ79A10WEZ5sekdrdtvX2P617VHb6zqZF4Ducrvn2W3PkPQbSd+QtF/Sq5JWRcRbFZ9hzQ50WTfW7EsljUbE3oj4vaSfSFrZwfwAdFEnYV8o6bcTXu8vpp3E9pDtYdvDHSwLQIe6foAuIjZK2iixGQ80qZM1+wFJiya8/nIxDUAf6iTsr0q60PZXbM+U9G1Jz9XTFoC6tb0ZHxHHbN8m6ReSZkh6MiLerK0zALVq+9RbWwtjnx3ouq5cVAPgzEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI9HbIZmOiiiy6qrD/22GOV9RtuuKGyPjY2dto9TWes2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiWlznn327NmV9VmzZlXWjx49Wln/9NNPT7snVFuxYkVl/Yorrqis33zzzZX1Bx98sLR27Nixys9ORx2F3fY+SR9L+kzSsYgYrKMpAPWrY83+VxHxfg3zAdBF7LMDSXQa9pD0S9uv2R6a7A22h2wP2x7ucFkAOtDpZvyyiDhg+48kPW/7fyLipYlviIiNkjZKku3ocHkA2tTRmj0iDhSPhyT9TNLSOpoCUL+2w277bNuzTzyXtFzSnroaA1AvR7S3ZW17scbX5tL47sC/RcT3W3yma5vx999/f2X9rrvuqqyvXbu2sv7II4+cdk+otmzZssr6Cy+80NH8L7744tLa6OhoR/PuZxHhyaa3vc8eEXsl/VnbHQHoKU69AUkQdiAJwg4kQdiBJAg7kMS0+YlrpzZs2FBZ37t3b2nt2WefrbudFM4999ymW0iFNTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59kKrW01v2rSptLZ8+fLKzw4P570jV9X3euedd3Z12dddd11preo209MVa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSGLanGfft29fV+d/zjnnlNbuu+++ys/eeOONlfUjR4601dOZ4IILLiitLV3KmCK9xJodSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Joe8jmthbWxSGbZ8yYUVlfv359Zb3VfeM7ceutt1bWH3/88a4tu2nnnXdeaa3VkMyLFy/uaNkM2Xyylmt220/aPmR7z4Rp82w/b/vt4nFunc0CqN9UNuN/JOmaU6atk7QjIi6UtKN4DaCPtQx7RLwk6fApk1dK2lw83yzp2pr7AlCzdq+NXxARY8Xz9yQtKHuj7SFJQ20uB0BNOv4hTERE1YG3iNgoaaPU3QN0AKq1e+rtoO0BSSoeD9XXEoBuaDfsz0laXTxfLYkxi4E+1/I8u+0tkq6UNF/SQUkbJP1c0k8l/bGkdyR9KyJOPYg32bwa24yfM2dOZX3nzp2V9arfZbeye/fuyvrVV19dWf/ggw/aXnbTlixZUlrr9v30Oc9+spb77BGxqqR0VUcdAegpLpcFkiDsQBKEHUiCsANJEHYgiWlzK+lWjh49Wll/+eWXK+udnHq75JJLKuuLFi2qrHfz1NvMmTMr62vWrOlo/lXDJqO3WLMDSRB2IAn
CDiRB2IEkCDuQBGEHkiDsQBJpzrO38sorr1TWV69eXVnvxGWXXVZZ37VrV2X98ssvb6smSbNmzaqs33333ZX1Jo2MjFTWp/NQ2O1gzQ4kQdiBJAg7kARhB5Ig7EAShB1IgrADSUybIZu77emnny6tXX/99T3spF5nnVX9//3x48d71En9hobKRx174oknethJb7U9ZDOA6YGwA0kQdiAJwg4kQdiBJAg7kARhB5LgPPsUNTn0cDfZk56S/Vwv/33UbdOmTaW1W265pYed9Fbb59ltP2n7kO09E6bda/uA7V3F34o6mwVQv6lsxv9I0jWTTH8kIpYUf/9eb1sA6tYy7BHxkqTDPegFQBd1coDuNttvFJv5c8veZHvI9rDtM3fHFpgG2g37DyR9VdISSWOSHi57Y0RsjIjBiBhsc1kAatBW2CPiYER8FhHHJf1Q0tJ62wJQt7bCbntgwstvStpT9l4A/aHlfeNtb5F0paT5tvdL2iDpSttLJIWkfZI6G8QbjRkdHa2stzrPvn379sr60aNHS2v33HNP5WdRr5Zhj4hVk0yevr/8B6YpLpcFkiDsQBKEHUiCsANJEHYgCYZsPgMcPlz904R33323tPbww6UXN0qStmzZ0lZPU1X102BOvfUWa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSILz7FO0d+/e0tpTTz1V+dnFixdX1kdGRirrjz76aGV9zx5uJzCZ5cuXl9bmzi29k5ok6ciRI3W30zjW7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZp+ijjz4qrd1000097ARTtXDhwtLazJkze9hJf2DNDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ4dXfXhhx+W1sbGxio/OzAwUHc7n3vggQcq62vWVI9CfuzYsTrb6YmWa3bbi2z/yvZbtt+0/b1i+jzbz9t+u3isvhsAgEZNZTP+mKS/i4ivSfoLSd+1/TVJ6yTtiIgLJe0oXgPoUy3DHhFjEfF68fxjSSOSFkpaKWlz8bbNkq7tVpMAOnda++y2z5f0dUk7JS2IiBM7Xe9JWlDymSFJQ+23CKAOUz4ab3uWpGck3R4RJ/0qJCJCUkz2uYjYGBGDETHYUacAOjKlsNv+ksaD/uOI2FpMPmh7oKgPSDrUnRYB1MHjK+WKN9jW+D754Yi4fcL0f5b0QUQ8ZHudpHkR8fct5lW9MKRy6aWXVta3bt1aWV+wYNI9x1rMmTOnsv7JJ590bdmdighPNn0q++x/KelvJO22vauYtl7SQ5J+avs7kt6R9K06GgXQHS3DHhH/JWnS/ykkXVVvOwC6hctlgSQIO5AEYQeSIOxAEoQdSKLlefZaF8Z5dpyGwcHqiy63bdtWWZ8/f37by77qquoTTS+++GLb8+62svPsrNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAluJY2+NTw8XFm/4447Kutr164trW3fvr2jZZ+JWLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL8nh2YZvg9O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4k0TLsthfZ/pXtt2y/aft7xfR7bR+wvav4W9H9dgG0q+VFNbYHJA1ExOu2Z0t6TdK1Gh+P/XcR8S9TXhgX1QBdV3ZRzVTGZx+TNFY8/9j2iKSF9bYHoNtOa5/d9vmSvi5pZzHpNttv2H7S9tySzwzZHrY9/e7zA5xBpnxtvO1Zkl6U9P2I2Gp7gaT3JYWk+zW+qX9Ti3mwGQ90Wdlm/JTCbvtLkrZJ+kVE/Osk9fMlbYuIP20xH8IOdFnbP4SxbUlPSBqZGPTiwN0J35S0p9MmAXTPVI7GL5P0n5J2SzpeTF4vaZWkJRrfjN8naU1xMK9qXqzZgS7raDO+LoQd6D5+zw4kR9iBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii5Q0na/a+pHcmvJ5fTOtH/dpbv/Yl0Vu
76uztT8oKPf09+xcWbg9HxGBjDVTo1976tS+J3trVq97YjAeSIOxAEk2HfWPDy6/Sr731a18SvbWrJ701us8OoHeaXrMD6BHCDiTRSNhtX2P717ZHba9roocytvfZ3l0MQ93o+HTFGHqHbO+ZMG2e7edtv108TjrGXkO99cUw3hXDjDf63TU9/HnP99ltz5D0G0nfkLRf0quSVkXEWz1tpITtfZIGI6LxCzBsXyHpd5KeOjG0lu1/knQ4Ih4q/qOcGxH/0Ce93avTHMa7S72VDTP+t2rwu6tz+PN2NLFmXyppNCL2RsTvJf1E0soG+uh7EfGSpMOnTF4paXPxfLPG/7H0XElvfSEixiLi9eL5x5JODDPe6HdX0VdPNBH2hZJ+O+H1fvXXeO8h6Ze2X7M91HQzk1gwYZit9yQtaLKZSbQcxruXThlmvG++u3aGP+8UB+i+aFlE/Lmkv5b03WJztS/F+D5YP507/YGkr2p8DMAxSQ832UwxzPgzkm6PiI8m1pr87ibpqyffWxNhPyBp0YTXXy6m9YWIOFA8HpL0M43vdvSTgydG0C0eDzXcz+ci4mBEfBYRxyX9UA1+d8Uw489I+nFEbC0mN/7dTdZXr763JsL+qqQLbX/F9kxJ35b0XAN9fIHts4sDJ7J9tqTl6r+hqJ+TtLp4vlrSsw32cpJ+Gca7bJhxNfzdNT78eUT0/E/SCo0fkf9fSf/YRA8lfS2W9N/F35tN9yZpi8Y36/5P48c2viPpDyXtkPS2pP+QNK+Penta40N7v6HxYA001NsyjW+ivyFpV/G3ounvrqKvnnxvXC4LJMEBOiAJwg4kQdiBJAg7kARhB5Ig7EAShB1I4v8B1DNMfBo+lxwAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Invert the normalisation (x * std + mean) so the image is shown with its\n",
    "# pre-normalisation pixel intensities.\n",
    "random_img = train_dataset[20][0].numpy() * stddev_gray + mean_gray\n",
    "plt.imshow(random_img.reshape(28,28), cmap='gray')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n"
     ]
    }
   ],
   "source": [
    "print(train_dataset[20][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "\n",
    "train_load = torch.utils.data.DataLoader(dataset=train_dataset,\n",
    "                                         batch_size=batch_size,\n",
    "                                         shuffle=True)\n",
    "\n",
    "# Evaluation does not depend on sample order, so the test loader is left\n",
    "# unshuffled to keep its iteration order reproducible.\n",
    "test_load = torch.utils.data.DataLoader(dataset=test_dataset,\n",
    "                                        batch_size=batch_size,\n",
    "                                        shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(test_load)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "    \"\"\"Two-block convolutional classifier producing 10 logits per image.\n",
    "\n",
    "    Each block is Conv2d -> BatchNorm2d -> ReLU -> 2x2 MaxPool2d. The pooled\n",
    "    feature maps are flattened and fed through two fully connected layers\n",
    "    with dropout in between. The first linear layer expects 32*7*7 = 1568\n",
    "    features, i.e. single-channel 28x28 inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Same padding keeps output_size == input_size at stride 1:\n",
    "        #   same_padding = (filter_size - 1) / 2, here (3 - 1) / 2 = 1\n",
    "        # General conv output size:\n",
    "        #   [(input_size - filter_size + 2 * padding) / stride] + 1\n",
    "        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3,3), stride=1, padding=1)\n",
    "        self.batchnorm1 = nn.BatchNorm2d(8)\n",
    "        self.relu = nn.ReLU()\n",
    "        # 2x2 pooling halves each spatial dimension: 28 / 2 = 14\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=(2,2))\n",
    "        # Same padding for the 5x5 kernel is (5 - 1) / 2 = 2\n",
    "        self.cnn2 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=(5,5), stride=1, padding=2)\n",
    "        self.batchnorm2 = nn.BatchNorm2d(32)\n",
    "        # Second pooling: 14 / 2 = 7, so the flattened size is 32*7*7 = 1568\n",
    "        self.fc1 = nn.Linear(1568,600)\n",
    "        self.dropout = nn.Dropout(p=0.5)\n",
    "        self.fc2 = nn.Linear(600,10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of shape (N, 1, 28, 28) to logits of shape (N, 10).\"\"\"\n",
    "        out = self.cnn1(x)\n",
    "        out = self.batchnorm1(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.maxpool(out)\n",
    "        out = self.cnn2(out)\n",
    "        out = self.batchnorm2(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.maxpool(out)\n",
    "        # Flatten the 32 feature maps to (N, 1568). N is read from the tensor\n",
    "        # itself rather than the notebook-level batch_size, so batches of any\n",
    "        # size (single images, a partial final batch) pass through.\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc1(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.dropout(out)\n",
    "        out = self.fc2(out)\n",
    "\n",
    "        return out"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}