{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"from PIL import Image\n",
"# the previous imports are to process my own images\n",
"import torch\n",
"import torch.nn as nn\n",
"import torchvision.transforms as transforms\n",
"import torchvision.datasets as datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"mean_gray = 0.1307\n",
"stddev_gray = 0.3081\n",
"\n",
"# input[channel] = (input[channel]-mean[channel]) / std[channel]\n",
"\n",
"transforms_ori = transforms.Compose([transforms.ToTensor(),\n",
" transforms.Normalize((mean_gray,),(stddev_gray,))])\n",
"\n",
"train_dataset = datasets.MNIST(root='./data', \n",
" train=True, \n",
" transform=transforms_ori, \n",
" download=True)\n",
"\n",
"test_dataset = datasets.MNIST(root='./data', \n",
" train=False, \n",
" transform=transforms_ori)\n",
"\n",
"# The next code is to transform the image\n",
"transforms_photo = transforms.Compose([transforms.Resize((28,28)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((mean_gray,),(stddev_gray,))])\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff75d0d37d0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANr0lEQVR4nO3db6xUdX7H8c9HujwRJFBSvLK0LqvGbGrKNjdYLTE2usTyBPeBm0VtaFy9mKzJqg0tUiMas2raWh+ZNazKotmy2UR2NdBk15JVW2OIV0MFvd31lqALuUIURFcfbJFvH9yDueA9Zy4zZ+YM9/t+JTczc74z53wz4cP5O+fniBCA6e+sphsA0BuEHUiCsANJEHYgCcIOJPEHvVyYbQ79A10WEZ5sekdrdtvX2P617VHb6zqZF4Ducrvn2W3PkPQbSd+QtF/Sq5JWRcRbFZ9hzQ50WTfW7EsljUbE3oj4vaSfSFrZwfwAdFEnYV8o6bcTXu8vpp3E9pDtYdvDHSwLQIe6foAuIjZK2iixGQ80qZM1+wFJiya8/nIxDUAf6iTsr0q60PZXbM+U9G1Jz9XTFoC6tb0ZHxHHbN8m6ReSZkh6MiLerK0zALVq+9RbWwtjnx3ouq5cVAPgzEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBI9HbIZmOiiiy6qrD/22GOV9RtuuKGyPjY2dto9TWes2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiWlznn327NmV9VmzZlXWjx49Wln/9NNPT7snVFuxYkVl/Yorrqis33zzzZX1Bx98sLR27Nixys9ORx2F3fY+SR9L+kzSsYgYrKMpAPWrY83+VxHxfg3zAdBF7LMDSXQa9pD0S9uv2R6a7A22h2wP2x7ucFkAOtDpZvyyiDhg+48kPW/7fyLipYlviIiNkjZKku3ocHkA2tTRmj0iDhSPhyT9TNLSOpoCUL+2w277bNuzTzyXtFzSnroaA1AvR7S3ZW17scbX5tL47sC/RcT3W3yma5vx999/f2X9rrvuqqyvXbu2sv7II4+cdk+otmzZssr6Cy+80NH8L7744tLa6OhoR/PuZxHhyaa3vc8eEXsl/VnbHQHoKU69AUkQdiAJwg4kQdiBJAg7kMS0+YlrpzZs2FBZ37t3b2nt2WefrbudFM4999ymW0iFNTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMF59kKrW01v2rSptLZ8+fLKzw4P570jV9X3euedd3Z12dddd11preo209MVa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSGLanGfft29fV+d/zjnnlNbuu+++ys/eeOONlfUjR4601dOZ4IILLiitLV3KmCK9xJodSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Joe8jmthbWxSGbZ8yYUVlfv359Zb3VfeM7ceutt1bWH3/88a4tu2nnnXdeaa3VkMyLFy/uaNkM2Xyylmt220/aPmR7z4Rp82w/b/vt4nFunc0CqN9UNuN/JOmaU6atk7QjIi6UtKN4DaCPtQx7RLwk6fApk1dK2lw83yzp2pr7AlCzdq+NXxARY8Xz9yQtKHuj7SFJQ20uB0BNOv4hTERE1YG3iNgoaaPU3QN0AKq1e+rtoO0BSSoeD9XXEoBuaDfsz0laXTxfLYkxi4E+1/I8u+0tkq6UNF/SQUkbJP1c0k8l/bGkdyR9KyJOPYg32bwa24yfM2dOZX3nzp2V9arfZbeye/fuyvrVV19dWf/ggw/aXnbTlixZUlrr9v30Oc9+spb77BGxqqR0VUcdAegpLpcFkiDsQBKEHUiCsANJEHYgiWlzK+lWjh49Wll/+eWXK+udnHq75JJLKuuLFi2qrHfz1NvMmTMr62vWrOlo/lXDJq
O3WLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBJpzrO38sorr1TWV69eXVnvxGWXXVZZ37VrV2X98ssvb6smSbNmzaqs33333ZX1Jo2MjFTWp/NQ2O1gzQ4kQdiBJAg7kARhB5Ig7EAShB1IgrADSUybIZu77emnny6tXX/99T3spF5nnVX9//3x48d71En9hobKRx174oknethJb7U9ZDOA6YGwA0kQdiAJwg4kQdiBJAg7kARhB5LgPPsUNTn0cDfZk56S/Vwv/33UbdOmTaW1W265pYed9Fbb59ltP2n7kO09E6bda/uA7V3F34o6mwVQv6lsxv9I0jWTTH8kIpYUf/9eb1sA6tYy7BHxkqTDPegFQBd1coDuNttvFJv5c8veZHvI9rDtM3fHFpgG2g37DyR9VdISSWOSHi57Y0RsjIjBiBhsc1kAatBW2CPiYER8FhHHJf1Q0tJ62wJQt7bCbntgwstvStpT9l4A/aHlfeNtb5F0paT5tvdL2iDpSttLJIWkfZI6G8QbjRkdHa2stzrPvn379sr60aNHS2v33HNP5WdRr5Zhj4hVk0yevr/8B6YpLpcFkiDsQBKEHUiCsANJEHYgCYZsPgMcPlz904R33323tPbww6UXN0qStmzZ0lZPU1X102BOvfUWa3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSILz7FO0d+/e0tpTTz1V+dnFixdX1kdGRirrjz76aGV9zx5uJzCZ5cuXl9bmzi29k5ok6ciRI3W30zjW7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOfZp+ijjz4qrd1000097ARTtXDhwtLazJkze9hJf2DNDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJcJ4dXfXhhx+W1sbGxio/OzAwUHc7n3vggQcq62vWVI9CfuzYsTrb6YmWa3bbi2z/yvZbtt+0/b1i+jzbz9t+u3isvhsAgEZNZTP+mKS/i4ivSfoLSd+1/TVJ6yTtiIgLJe0oXgPoUy3DHhFjEfF68fxjSSOSFkpaKWlz8bbNkq7tVpMAOnda++y2z5f0dUk7JS2IiBM7Xe9JWlDymSFJQ+23CKAOUz4ab3uWpGck3R4RJ/0qJCJCUkz2uYjYGBGDETHYUacAOjKlsNv+ksaD/uOI2FpMPmh7oKgPSDrUnRYB1MHjK+WKN9jW+D754Yi4fcL0f5b0QUQ8ZHudpHkR8fct5lW9MKRy6aWXVta3bt1aWV+wYNI9x1rMmTOnsv7JJ590bdmdighPNn0q++x/KelvJO22vauYtl7SQ5J+avs7kt6R9K06GgXQHS3DHhH/JWnS/ykkXVVvOwC6hctlgSQIO5AEYQeSIOxAEoQdSKLlefZaF8Z5dpyGwcHqiy63bdtWWZ8/f37by77qquoTTS+++GLb8+62svPsrNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAluJY2+NTw8XFm/4447Kutr164trW3fvr2jZZ+JWLMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL8nh2YZvg9O5AcYQeSIOxAEoQdSIKwA0kQdiAJwg4k0TLsthfZ/pXtt2y/aft7xfR7bR+wvav4W9H9dgG0q+VFNbYHJA1ExOu2Z0t6TdK1Gh+P/XcR8S9TXhgX1QBdV3ZRzVTGZx+TNFY8/9j2iKSF9bYHoNtOa5/d9vmSvi5pZzHpNttv2H7S9tySzwzZHrY9/e7zA5xBpnxtvO1Zkl6U9P2I2Gp7gaT3JYWk+zW+qX9Ti3mwGQ90Wdlm/JTCbvtLkrZJ+kVE/Osk9fMlbYuIP20xH8IOdFnbP4SxbUlPSBqZGPTiwN0J35S0p9MmAXTPVI7GL5P0n5J2SzpeTF4vaZWkJRrfjN8naU1xMK9qXqzZgS7raDO+LoQd6D5+zw4kR9iBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii5Q0na/a+pHcmvJ5fTO
tH/dpbv/Yl0Vu76uztT8oKPf09+xcWbg9HxGBjDVTo1976tS+J3trVq97YjAeSIOxAEk2HfWPDy6/Sr731a18SvbW
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"random_img = train_dataset[20][0].numpy() * stddev_gray + mean_gray\n",
"plt.imshow(random_img.reshape(28,28), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
}
],
"source": [
"print(train_dataset[20][1])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 100\n",
"\n",
"train_load = torch.utils.data.DataLoader(dataset=train_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True)\n",
"\n",
"test_load = torch.utils.data.DataLoader(dataset=test_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"100"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(test_load)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
" def __init__(self):\n",
" super(CNN,self).__init__()\n",
" # Same padding means that input_size = output_size\n",
" # same_padding = (filter_size - 1) / 2\n",
" self.cnn1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(3,3), stride=1, padding=1)\n",
" # The output from this layer has size:\n",
" # [(input_size - filter_size + 2*(padding)) / stride] + 1\n",
" self.batchnorm1 = nn.BatchNorm2d(8)\n",
" self.relu = nn.ReLU()\n",
" self.maxpool = nn.MaxPool2d(kernel_size=(2,2))\n",
" # Pooling output size is 28/2 = 14 (downsampling)\n",
" # Same padding size is (5 - 1)/2 = 2\n",
" self.cnn2 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=(5,5), stride=1, padding=2)\n",
" self.batchnorm2 = nn.BatchNorm2d(32)\n",
" # Pooling output size is 14/2 = 7 (downsampling)\n",
" # We have to flatten the output channels 32*7*7 = 1568\n",
" self.fc1 = nn.Linear(1568,600)\n",
" self.dropout = nn.Dropout(p=0.5)\n",
" self.fc2 = nn.Linear(600,10)\n",
" \n",
" def forward(self,x):\n",
" out = self.cnn1(x)\n",
" out = self.batchnorm1(out)\n",
" out = self.relu(out)\n",
" out = self.maxpool(out)\n",
" out = self.cnn2(out)\n",
" out = self.batchnorm2(out)\n",
" out = self.relu(out)\n",
" out = self.maxpool(out)\n",
" # we have to flatten the 32 feature maps, output of our last maxpool (100, 1568)\n",
" out = out.view(-1, 1568) # <- setting the number of rows to -1 is important, because \n",
" out = self.fc1(out) # when we try to predict we would require to use a size 100 (batch size)\n",
" out = self.relu(out) # tensor but -1 infers the size.\n",
" out = self.dropout(out)\n",
" out = self.fc2(out)\n",
" \n",
" return out\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"model = CNN()\n",
"\n",
"CUDA = torch.cuda.is_available()\n",
"\n",
"if CUDA:\n",
" model = model.cuda()\n",
" \n",
"loss_fn = nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"print(CUDA)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"For one iteration this is what happens:\n",
"Input shape: torch.Size([100, 1, 28, 28])\n",
"Labels shape: torch.Size([100])\n",
"Output shape: torch.Size([100, 10])\n",
"Predicted shape: torch.Size([100, 1, 28, 28])\n",
"Predicted tensor: \n",
"tensor([9, 4, 9, 4, 0, 7, 3, 7, 8, 4, 0, 0, 4, 9, 9, 8, 4, 0, 6, 4, 9, 0, 3, 8,\n",
" 8, 8, 0, 0, 8, 9, 3, 8, 4, 0, 4, 8, 9, 0, 4, 4, 0, 9, 0, 9, 9, 7, 4, 4,\n",
" 8, 0, 0, 8, 8, 0, 0, 4, 9, 9, 9, 4, 8, 4, 0, 4, 0, 8, 0, 0, 4, 8, 4, 4,\n",
" 7, 8, 5, 4, 8, 9, 9, 8, 7, 8, 0, 9, 0, 4, 0, 9, 9, 8, 5, 0, 4, 9, 4, 7,\n",
" 9, 4, 4, 9], device='cuda:0')\n"
]
}
],
"source": [
"# Understand what is happening\n",
"iteration = 0\n",
"correct = 0\n",
"\n",
"for i,(inputs, labels) in enumerate(train_load):\n",
" \n",
" if CUDA:\n",
" inputs = inputs.cuda()\n",
" labels = labels.cuda()\n",
" \n",
" print('For one iteration this is what happens:')\n",
" print('Input shape: ', inputs.shape)\n",
" print('Labels shape: ', labels.shape)\n",
" output = model(inputs)\n",
" print('Output shape: ', output.shape)\n",
" _, predicted = torch.max(output, dim=1)\n",
" print('Predicted shape: ', predicted.shape)\n",
" print('Predicted tensor: ')\n",
" print(predicted)\n",
" correct += (predicted==labels).sum()\n",
" break\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10, Training loss: 0.574, Training accuracy: 0.840, Testing loss: 3.446, Testing accuracy: 0.974\n",
"Epoch 2/10, Training loss: 0.151, Training accuracy: 0.956, Testing loss: 0.907, Testing accuracy: 0.984\n",
"Epoch 3/10, Training loss: 0.105, Training accuracy: 0.971, Testing loss: 0.629, Testing accuracy: 0.987\n",
"Epoch 4/10, Training loss: 0.085, Training accuracy: 0.975, Testing loss: 0.511, Testing accuracy: 0.989\n",
"Epoch 5/10, Training loss: 0.080, Training accuracy: 0.978, Testing loss: 0.477, Testing accuracy: 0.988\n",
"Epoch 6/10, Training loss: 0.068, Training accuracy: 0.981, Testing loss: 0.407, Testing accuracy: 0.990\n",
"Epoch 7/10, Training loss: 0.064, Training accuracy: 0.982, Testing loss: 0.385, Testing accuracy: 0.990\n",
"Epoch 8/10, Training loss: 0.062, Training accuracy: 0.982, Testing loss: 0.369, Testing accuracy: 0.989\n",
"Epoch 9/10, Training loss: 0.058, Training accuracy: 0.983, Testing loss: 0.346, Testing accuracy: 0.986\n",
"Epoch 10/10, Training loss: 0.055, Training accuracy: 0.984, Testing loss: 0.331, Testing accuracy: 0.988\n"
]
}
],
"source": [
"# Training of the CNN\n",
"num_epochs = 10\n",
"train_loss = list()\n",
"train_accuracy = list()\n",
"test_loss = list()\n",
"test_accuracy = list()\n",
"\n",
"\n",
"for epoch in range(num_epochs):\n",
" correct = 0\n",
" iterations = 0\n",
" iter_loss = 0.0\n",
" \n",
" # When batch normalization and dropout are used\n",
" # we have to specify our model that we are training\n",
" model.train()\n",
" \n",
" for i,(inputs, labels) in enumerate(train_load):\n",
" \n",
" if CUDA:\n",
" inputs = inputs.cuda()\n",
" labels = labels.cuda()\n",
" \n",
" outputs = model(inputs)\n",
" loss = loss_fn(outputs,labels)\n",
" iter_loss += loss.item()\n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" _, predicted = torch.max(outputs, dim=1)\n",
" correct += (predicted==labels).sum().item()\n",
" iterations += 1\n",
" \n",
" train_loss.append(iter_loss/iterations)\n",
" train_accuracy.append(correct/len(train_dataset))\n",
" \n",
" # Testing phase\n",
" testing_loss = 0.0\n",
" correct = 0\n",
" iterations = 0\n",
" \n",
" model.eval()\n",
" \n",
" for i,(inputs, labels) in enumerate(test_load):\n",
" \n",
" if CUDA:\n",
" inputs = inputs.cuda()\n",
" labels = labels.cuda()\n",
" \n",
" outputs = model(inputs)\n",
" loss = loss_fn(outputs,labels)\n",
" testing_loss += loss.item()\n",
" \n",
" _, predicted = torch.max(outputs, dim=1)\n",
" correct += (predicted==labels).sum().item()\n",
" iterations += 1\n",
" \n",
" test_loss.append(testing_loss/iterations)\n",
" test_accuracy.append(correct/len(test_dataset))\n",
" \n",
" print('Epoch {}/{}, Training loss: {:.3f}, Training accuracy: {:.3f}, Testing loss: {:.3f}, Testing accuracy: {:.3f}'.format(\n",
" epoch+1,num_epochs,train_loss[-1],train_accuracy[-1],test_loss[-1],test_accuracy[-1]))\n",
"\n",
"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlMAAAI/CAYAAABTd1zJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOzdeZScZZ33//dV1WvSna0qQEgICXSxhC0MYVcmMCogKqjogCggKouOqKM/RX0YcX7PPDqPzui4IiKijgM4MD9ExV0REAUChh0EwpYQIOnsW6/X74+qXtLppLfqvmt5v86p03Uvddc33edwPlz3dX3vEGNEkiRJo5NKugBJkqRyZpiSJEkaA8OUJEnSGBimJEmSxsAwJUmSNAaGKUmSpDGoSeqLs9lsnDdvXlJfL0mSNGz33Xff6hjjzMGOJRam5s2bx5IlS5L6ekmSpGELITy3s2Pe5pMkSRoDw5QkSdIYGKYkSZLGILE5U5IkaXsdHR0sX76cbdu2JV1K1WpoaGDOnDnU1tYO+zOGKUmSSsTy5ctpbm5m3rx5hBCSLqfqxBhpbW1l+fLlzJ8/f9if8zafJEklYtu2bWQyGYNUQkIIZDKZEY8MGqYkSSohBqlkjeb3b5iSJEkAtLa2snDhQhYuXMgee+zB7Nmze7fb29t3+dklS5Zw6aWXDvkdxx13XFFqve2223jDG95QlGuNlXOmJEkSAJlMhqVLlwJwxRVX0NTUxMc+9rHe452dndTUDB4dFi1axKJFi4b8jrvuuqs4xZYQR6YkSdJOnX/++Vx88cUcffTRfPzjH+eee+7h2GOP5fDDD+e4447jiSeeALYfKbriiiu44IILWLx4Mfvssw9f+cpXeq/X1NTUe/7ixYs588wzOeCAAzjnnHOIMQJw6623csABB3DEEUdw6aWXDjkCtWbNGs444wwOPfRQjjnmGB588EEA/vCHP/SOrB1++OFs3LiRlStXcsIJJ7Bw4UIOPvhg7rjjjjH/jhyZkiRJu7R8+XLuuusu0uk0GzZs4I477qCmpobf/OY3fOpTn+Kmm27a4TOPP/44v//979m4cSP7778/l1xyyQ7tBv7yl7/wyCOPsOeee3L88cfzxz/+kUWLFnHRRRdx++23M3/+fM4+++wh6/vMZz7D4Ycfzs0338zvfvc7zj33XJYuXcoXv/hFvv71r3P88cezadMmGhoauOqqqzj55JP59Kc/TVdXF1u2bBnz72fIMBVCaABuB+oL598YY/zMgHPOB74ArCjs+lqM8eoxVydJUpX67E8e4dEXNxT1mgv2nMJn3njQiD/3tre9jXQ6DcD69es577zzePLJJwkh0NHRMehnTjvtNOrr66mvr2e33Xbj5ZdfZs6cOdudc9RRR/XuW7hwIc8++yxNTU3ss88+va0Jzj77bK666qpd1nfnnXf2BrqTTjqJ1tZWNmzYwPHHH88//uM/cs455/CWt7yFOXPmcOSRR3LBBRfQ0dHBGWecwcKFC0f8+xhoOLf52oCTYoyHAQuBU0IIxwxy3g0xxoWFl0FKkqQKMXny5N73l19+OSeeeCIPP/wwP/nJT3baRqC+vr73fTqdprOzc1TnjMVll13G1VdfzdatWzn++ON5/PHHOeGEE7j99tuZPXs2559/Pt///vfH/D1DjkzF/A3MTYXN2sIrjvmbJUnSTo1mBGkirF+/ntmzZwNw7bXXFv36+++/P8uWLePZZ59l3rx53HDDDUN+5tWvfjU//OEPufzyy7ntttvIZrNMmTKFp59+mkMOOYRDDjmEe++9l8cff5zGxkbmzJnD+973Ptra2rj//vs599xzx1TzsCaghxDSIYSlwCvAr2OMdw9y2ltDCA+GEG4MIew1pqokSVJJ+vjHP84nP/lJDj/88KKPJAE0NjbyjW98g1NOOYUjjjiC5uZmpk6dusvPXHHFFdx3330ceuihXHbZZXzve9
8D4Mtf/jIHH3wwhx56KLW1tZx66qncdtttHHbYYRx++OHccMMNfOhDHxpzzaFn5vywTg5hGvD/AR+MMT7cb38G2BRjbAshXAT8fYzxpEE+fyFwIcDcuXOPeO6558ZavyRJFeOxxx7jwAMPTLqMxG3atImmpiZijHzgAx8gl8vxkY98ZMK+f7C/QwjhvhjjoL0fRtQaIca4Dvg9cMqA/a0xxrbC5tXAETv5/FUxxkUxxkUzZ84cyVdLkqQq8e1vf5uFCxdy0EEHsX79ei666KKkS9ql4azmmwl0xBjXhRAagdcC/zrgnFkxxpWFzTcBjxW9UkmSVBU+8pGPTOhI1FgNp8/ULOB7IYQ0+ZGsH8UYfxpC+GdgSYzxFuDSEMKbgE5gDXD+eBUsSZJUSoazmu9B4PBB9v9Tv/efBD5Z3NIkSZJKn4+TkSRJGgPDlCRJ0hhUbph6+RH46iJ49s6kK5EkqSy0trb2Phh4jz32YPbs2b3b7e3tQ37+tttu46677urdvvLKK4vSYRxg8eLFLFmypCjXKrbKfdBx43RofRJWPQ7zXpV0NZIklbxMJsPSpUuBfCPMpqYmPvaxjw3787fddhtNTU0cd9xxAFx88cXjUmepqdyRqeZZUNcEq59KuhJJksrWfffdx9/+7d9yxBFHcPLJJ7NyZb4T0le+8hUWLFjAoYceyllnncWzzz7LlVdeyZe+9CUWLlzIHXfcwRVXXMEXv/hFID+y9IlPfIKjjjqK/fbbjzvuuAOALVu28Pa3v50FCxbw5je/maOPPnrIEajrrruOQw45hIMPPphPfOITAHR1dXH++edz8MEHc8ghh/ClL31p0DrHQ+WOTIUAmX3zo1OSJGnEYox88IMf5Mc//jEzZ87khhtu4NOf/jTXXHMNn//853nmmWeor69n3bp1TJs2jYsvvni70azf/va3212vs7OTe+65h1tvvZXPfvaz/OY3v+Eb3/gG06dP59FHH+Xhhx9m4cKFu6zpxRdf5BOf+AT33Xcf06dP53Wvex0333wze+21FytWrODhh/MPaFm3bh3ADnWOh8oNUwCZHCy/N+kqJEkauZ9fBi89VNxr7nEInPr5YZ/e1tbGww8/zGtf+1ogP/oza9YsAA499FDOOecczjjjDM4444xhXe8tb3kLAEcccQTPPvssAHfeeWfv8/F6nqO3K/feey+LFy+m50kq55xzDrfffjuXX345y5Yt44Mf/CCnnXYar3vd60Zd50hV7m0+gGwO1j0PHduSrkSSpLITY+Sggw5i6dKlLF26lIceeohf/epXAPzsZz/jAx/4APfffz9HHnnksB56XF9fD0A6nS76Q5KnT5/OAw88wOLFi7nyyit573vfO+o6R6rCR6ZagAhrlsHuC5KuRpKk4RvBCNJ4qa+vZ9WqVfzpT3/i2GOPpaOjg7/+9a8ceOCBvPDCC5x44om86lWv4vrrr2fTpk00NzezYcOGEX3H8ccfz49+9CNOPPFEHn30UR56aNejcUcddRSXXnopq1evZvr06Vx33XV88IMfZPXq1dTV1fHWt76V/fffn3e+8510d3cPWue0adPG8mvZQRWEKfLzpgxTkiSNSCqV4sYbb+TSSy9l/fr1dHZ28uEPf5j99tuPd77znaxfv54YI5deeinTpk3jjW98I2eeeSY//vGP+epXvzqs73j/+9/Peeedx4IFCzjggAM46KCDmDp16k7PnzVrFp///Oc58cQTiTFy2mmncfrpp/PAAw/w7ne/m+7ubgA+97nP0dXVNWidxRZijEW/6HAsWrQojnu/iLZN8LnZcNLlcMLwl3ZKkpSExx57jAMPPDDpMiZUV1cXHR0dNDQ08PTTT/Oa17yGJ554grq6usRqGuzvEEK4L8a4aLDzK3tkqr4JmveEVtsjSJJUirZs2cKJJ55IR0cHMUa+8Y1vJBqkRqOywxRAtgVW2x5BkqRS1NzcXLKdzYerslfzQb49QuuTkNDtTEmSVNkqP0
xlc7BtPWxenXQlkiQNKam5zMobze+/8sNUJpf/aSd0SVKJa2hooLW11UCVkBgjra2tNDQ0jOhz1TFnCvLzpvY+Ltl
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plotting the loss\n",
"\n",
"f = plt.figure(figsize=(10,10))\n",
"plt.plot(train_loss, label='Training loss')\n",
"plt.plot(test_loss, label='Testing loss')\n",
"plt.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlkAAAI/CAYAAABEVcwAAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOzdeXicZ33v/89XM9pHkrXZji3HkuIlOCROHMWBAE1YayglTaAUk0JC+JHAOYFfe0hpKDSloZyUlnKgp7SnCYSw9CQFWkp61ZAmISmUQGxnbzbb0Si27MTWSLIljZYZzdznj2ckjWTZHkszemZ5v65rLs2zjPwdOdF8fD/3873NOScAAABkV5nfBQAAABQjQhYAAEAOELIAAABygJAFAACQA4QsAACAHCBkAQAA5EDQ7wLmamlpce3t7X6XAQAAcEqPPvpoxDnXOt+xvAtZ7e3t2r17t99lAAAAnJKZvXSiY1wuBAAAyAFCFgAAQA4QsgAAAHKAkAUAAJADhCwAAIAcIGQBAADkACELAAAgBwhZAAAAOUDIAgAAyAFCFgAAQA4QsgAAAHKAkAUAAJADhCwAAIAcIGQBAADkACELAAAgBwhZAAAAOUDIAgAAyAFCFgAAQA4QsgAAAHKAkAUAAJADhCwAAIAcIGQBAADkQNDvAgCg6IwPSUOHpMlxySWkZDL1NTHn63z7k/Ocl639p6jjuNec4nsEyqWKkFRZJ1WmvlbUzd6eb9/UdrDC778pIKcIWQBwOhKT0vDL0rHe1ONA2vPUY+KYD4WZVBaQLJD2tWzO9sn2l81zXkAKBk/8+kRcig1LoxFpsEeaGJZiI94jE4GKVOgKSZX1aSEsLaDN2k6dN982gQ15iJAF5IPJmDR0cPYHtUmqXS7Vtkqh5VJti7ddUeN3tcVt7OjJA9TwIW+EJ111o9TQJjWuldpf5z2vXy0FqxYYeE6wvyx44jBk5s/Paz7JhBSLzoSuieHZj9iINDEkTYwcf87IESnWndoekeLRzP7MQOWJR9OmQ9ncIHeCYBcoz+3PByWDkAXkmnPS6MCcD+w5H94jhyW5zL5fea0UavXCV+3y2c9rW1KBLLWvujG/Pnz9lohnMAo1NPs1ZeVSw2qpYY3U8QYvQE0/1nhhqjLkz/vJV2UBqareeyxWMpEWwqZC2fAptkdmAlv/izP74qOZ/ZnBKi90VTdKzeuklvVS60apZYP3qF62+PeFkkDIAhZrcuL4Uai5H95zf7kHq2c+qNe/1fuwTv/wrl/tnRftm/0YOSJFI1L0iLc92CP17vIu18wdXZG8kY/a1plRsNrWOaEsfbu1sP8F75w0fvTkfw/DL88zCtWUGoXqkNrTQtSyM72vtcu9USX4oywgVTV4j8VKTM4EsLkjbLOC3JC3HY1I/fukFx+QErGZ7xNaMRO4WjZIramv9av5Rw1mIWQBJ+OcNNo/z4jHgTmjUHOEVngf0MtfJa1/28yox9TXmqbMfhkvW+M9TiWZ8EbLon2pABZJBbI525G93vbk+Pzfp7rxxKNix122rF3aD5RE3JtMPv13sP/4Uai5c4ECFd4HX0Ob1HHp8aNQDau994HSEAh6o1CnOxKVmJSOviRF9niPvj1S5AXp6R/Mnn9XEfJGvWYFsI1egGfO2NKaGPZ+X8RHpVUX+FaGOZfhJYol0tXV5Xbv3u13GSgV8fF5RqHmfHjPDSTBai/4zPqwnjMKFaz05/1kwjnvF9DJRshG+mYC2vgJJnEHq09+qXL6+XIvvJ1sNMg5aWzw1KNQcy+p1rSc4O8h9by2lVEo5I5zqX+8pEJXZK/Ul/o61DtzXlnQC1rTo15Tlx7XZ+eSaqmZGPEC1FBv6h9eB73f40MHU88PzYTflo3SDTtzWo6ZPeqc65r3GCELRcs5LzicbBQqeuT414VWpi4XrZn/w7vU5jlNTqQFsPQRsvkCWsS7vX8uC6RGwN
JGxYKVs0em5k5wDlScOMg2nCnVr+ImAOSviWEvbEX2pgJYagRs4EUpOTlzXt0Zs0e9WtZ7waBuZWn9npkSi878Xhg6mPb80EyQmu8ffrWt3j9w61d7I9T1q6T61M0oa7bmtGRCFopTfMz7H+/o3MtGB2b+B507ClVeM/+Ix/Qo1Kr8HoXKd8mkNyI1PSJ2khGy+HhqQvl8IWqNN0rFKBSKTSLuzaWM7EmNeqVdgowNz5xXWT8TuNIn3jd2eJc9C1Fs9PgRp7mjUeNHj39dTUsqOE09Vs1cNahf5fvvbUIWCtf0L6S9Un/qX4X9+7w7ho4bhTLvX4XHzbtJ2y61USgAhcE5afiVOZcdUwFs+OWZ88rKpabO2ZcdWzdIzev9vcs1Njp7tGk6SKWNRs0boJrnjEDNGY2qWyWVVy39+zkNJwtZBRqHUVSmLutNh6i9UmSf93WwZ/bQem2r90tl47bU3V9pIapuFZNLARQmM6n+DO/RednsY+NDx192PPK89PyO2Zfn69tm7nRMvwRZ27q4f1xOXTWYDk7po0+p0aixweNfV9OcGnVaI625+PjRqPrVeR+gFouQhaUTH5cGumePSE2FqvRr7IFKqfksafkmadPl3r/QWtZ7/WroTwOg1FTVS20Xeo90kzFpMDznsuML0mPfmT3Hsaph9qjX1CXIxnavNcX0CNQ885+OHZTGBo6vqbppZsRpzUXzjEatksqrc/pjKQSELGTX1JD3rCC1x3t+7MDsHkV1q6SWddKr35MKUeu97YY1Xm8cAMCJBSu8karWjbP3O+cFpOl2E6nHvvukJ747c15ZcPaVginVjTOhaXVXKji1zcyFqjuDm04yRMjCwsRGU3Oj0i7tRfZ6c6XSJ2+W13ijUm1d0ubtMyNSzevokg0AuWA2M43irDfNPjY26P3Ojrzg/Q4vr519N179GfSOyyJCFk4smfSutc+9tBfZN7sHjMwbfWpZ5113nwpSLeu90SruEAOA/FDd6F3eW3OR35WUBEIWvEmVx41Ipe7gmxybOa+y3gtP7a+bubTXvN4bqeLaOwAAsxCySkUykVoWIi1ITY1MpS8LY2XSsrXeBMnOy2ZGpJrXex28aX8AAEBGCFnFaOSItPe+2aNSA92zFzitbvSC07q3zA5STR004wQAFKRk0mloPK7B0bgGR2NyzunCtU2+1UPIKiaJuPTI30sP/bk3+bws6DWta14vbfj1tFYI66XaZr+rBQBgXs45jcYSOjoW12A0pqOp0HR0LK6j0ZgGR+M6mtoeHJ05fmwsrvQe62e11uqBT17m2/sgZBWL7oekHZ/y7hhZ91bpzTd7faYKdfkFAEBRiCeSOpoKRVMjTEeng9HU/ljqnJnQFEskT/g9aysCWlZTocbaci2rrtDqZdVqrKlQY025ltVUaFlNuRprKtRa5++VGT6BC93RA9K/f0Z69kdeY7ntd0sbtjF3CgCQVcmk0/D4pI6OzQ5Lg9G4N8I0mjbClBaWRibm6cWVUh4wLyylwtHa5hqdv2aZltWWT4emhmrva2OtF54aqstVGSyMXoqErEIVH5d++b+ln/2Vt/3Gz0qXfLzolygAACzOZCKp6ERCI7FJHZszwnQsdXlu1gjTWHx6JCp5guWOzaT6qvLpsNQSqtD65SE11JTPO8I09bWmIiAr4kEBQlYh2nOv9OM/9JZTeNW7pF//greOHwCg6MQTSY2mQlF0YlIjE5Pe9oS3HY3Nvy86kZg+P317YvLEl+EkqaYioGXV5dOX485YVu2FpOqZcNRYmzbCVFOh+upyBcqKNywtFCGrkPS/KP3k09Lee70WCx/4F+msN/pdFQAgTTyRnAlDsbTgM5EKOqlQNL09HZQSM+elQtHIxKRipwhFU8pMqq0MKlQZVO3UoyKgNbU1qX0B1VbMHAtVBtRQPXuEqaG6XFXlhXEprhAQsgpBLCr9/MvSw3/tLZ78tj+Ttl7vrVsFAMjYZCKpWCKp2KT3mJicvT
37WGIm+MRmQtHMaNFMKEoPVJmGokCZqbYioFBlUDVpwae5tiYVhALevulgFEgLSEHVpF47tV0ZLCvqS2+FiJCVz5y
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plotting the accuracy\n",
"\n",
"f = plt.figure(figsize=(10,10))\n",
"plt.plot(train_accuracy, label='Training accuracy')\n",
"plt.plot(test_accuracy, label='Testing accuracy')\n",
"plt.legend()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prediction is: 1\n",
"Actual label is: 1\n"
]
}
],
"source": [
"img = test_dataset[31][0].resize_((1,1,28,28))\n",
"label = test_dataset[31][1]\n",
"\n",
"model.eval()\n",
"\n",
"if CUDA:\n",
" model = model.cuda()\n",
" img = img.cuda()\n",
" \n",
"output = model(img)\n",
"_, predicted = torch.max(output, dim=1)\n",
"print('Prediction is: {}'.format(predicted.item()))\n",
"print('Actual label is: {}'.format(label))\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff75ca94310>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAL9klEQVR4nO3dX6gc5R3G8ecx6oXRi9jQQ9C0WhFjKDSGGAqNwSKKzU3ihWLAkFLJ8UJBoRcVixgoFSnV0ivlBMVYrCJoYhCppkGaehNyDGnMn2pSiZoQk0oujASxnvx6sRM5xrOzJzszO2t+3w8cdvd998+PIU/emXl39nVECMC577y2CwAwGIQdSIKwA0kQdiAJwg4kcf4gP8w2p/6BhkWEp2qvNLLbvtX2e7YP2H6wynsBaJb7nWe3PUPS+5JulnRI0nZJKyNib8lrGNmBhjUxsi+WdCAiPoiILyW9KGl5hfcD0KAqYb9M0seTHh8q2r7B9qjtcdvjFT4LQEWNn6CLiDFJYxK78UCbqozshyXNnfT48qINwBCqEvbtkq62faXtCyXdKWlTPWUBqFvfu/ER8ZXt+yS9IWmGpGciYk9tlQGoVd9Tb319GMfsQOMa+VINgO8Owg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Loe8lmfDds3769tP/kyZOl/atWrSrt/+ijj866JrSjUthtH5R0QtKEpK8iYlEdRQGoXx0j+88j4tMa3gdAgzhmB5KoGvaQ9Kbtd2yPTvUE26O2x22PV/wsABVU3Y1fEhGHbX9f0mbb/46IrZOfEBFjksYkyXZU/DwAfao0skfE4eL2mKQNkhbXURSA+vUddtszbV9y+r6kWyTtrqswAPWqshs/ImmD7dPv89eI+FstVWFglixZUtq/Zs2a0v6HH364znLQoL7DHhEfSPpJjbUAaBBTb0AShB1IgrADSRB2IAnCDiTBJa7nuHXr1pX2P/XUU6X9s2fPrrMctIiRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJ49uQh+PCgLRnYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSMKDnGdlRZjBmzdvXmn/3r17S/t7/fuYMWPGWdeEZkWEp2pnZAeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJJhnT25iYqK0v9e/j9tvv720f8OGDWddE6rpe57d9jO2j9nePantUtubbe8vbmfVWSyA+k1nN/5ZSbee0fagpC0RcbWkLcVjAEOsZ9gjYquk42c0L5e0vri/XtKKmusCULN+f4NuJCKOFPc/kTTS7Ym2RyWN9vk5AGpS+QcnIyLKTrxFxJikMYkTdECb+p16O2p7jiQVt8fqKwlAE/oN+yZJq4v7qyW9Wk85AJrSc57d9guSbpQ0W9JRSY9I2ijpJUk/kPShpDsi4syTeFO9F7vxQ6bX+uxr1qwp7d+xY0dp//XXX3/WNaGabvPsPY/ZI2Jll66bKlUEYKD4uiyQBGEHkiDsQBKEHUiCsANJsGQzSrGk87mDkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCePbm33367tL/XJa4zZ84s7b/ooou69p08ebL0tagXIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJMGSzcmVzYNL0rZt20r758+fX9pf9lPSvX6GGv3pe8lmAOcGwg4kQdiBJAg7kARhB5Ig7EAShB1IguvZk+t1TfkXX3xR2m9POaX7taVLl3btY559sHqO7LafsX3M9u5JbWttH7a9s/hb1myZAKqazm78s5JunaL9TxGxoPh7vd6yANStZ9gjYquk4wOoBUCDqpygu8/2rmI3f1a3J9ketT1ue7zCZwGoqN+wPynpKk
kLJB2R9Hi3J0bEWEQsiohFfX4WgBr0FfaIOBoRExFxStI6SYvrLQtA3foKu+05kx7eJml3t+cCGA4959ltvyDpRkmzbR+S9IikG20vkBSSDkq6p8Ea0aJ9+/aV9i9cuLC0/5prrqmzHFTQM+wRsXKK5qcbqAVAg/i6LJAEYQeSIOxAEoQdSIKwA0lwiStK9VrS+a677hpQJaiKkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmCeHZX0WvL72muvHVAl6IWRHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJ4dlfRasvmGG24YUCXohZEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Jgnh2V9LqevVc/BqfnyG57ru23bO+1vcf2/UX7pbY3295f3M5qvlwA/ZrObvxXkn4dEfMl/VTSvbbnS3pQ0paIuFrSluIxgCHVM+wRcSQidhT3T0jaJ+kyScslrS+etl7SiqaKBFDdWR2z275C0nWStkkaiYgjRdcnkka6vGZU0mj/JQKow7TPxtu+WNLLkh6IiM8m90XnLMyUZ2IiYiwiFkXEokqVAqhkWmG3fYE6QX8+Il4pmo/anlP0z5F0rJkSAdSh5268O9cwPi1pX0Q8Malrk6TVkh4rbl9tpEK0auvWraX9551XPl6cOnWqznJQwXSO2X8maZWkd23vLNoeUifkL9m+W9KHku5opkQAdegZ9oh4W1K3Xyi4qd5yADSFr8sCSRB2IAnCDiRB2IEkCDuQhAd5CaJtrnc8x0xMTJT2l/37Ov98rrBuQkRMOXvGyA4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDRiUo2btxY2r9iRfefJly6dGnpa3tdS4+zw8gOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kwz45KHn300dL+5cuXd+2bN29e6WuZZ68XIzuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJNHzd+Ntz5X0nKQRSSFpLCL+bHutpDWS/ls89aGIeL3He/G78UDDuv1u/HTCPkfSnIjYYfsSSe9IWqHOeuyfR8Qfp1sEYQea1y3s01mf/YikI8X9E7b3Sbqs3vIANO2sjtltXyHpOknbiqb7bO+y/YztWV1eM2p73PZ4pUoBVDLttd5sXyzpH5J+HxGv2B6R9Kk6x/G/U2dX/1c93oPdeKBhfR+zS5LtCyS9JumNiHhiiv4rJL0WET/u8T6EHWhY3ws72rakpyXtmxz04sTdabdJ2l21SADNmc7Z+CWS/inpXUmniuaHJK2UtECd3fiDku4pTuaVvRcjO9CwSrvxdSHsQPNYnx1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5DEoJds/lTSh5Mezy7ahtGw1jasdUnU1q86a/tht46BXs/+rQ+3xyNiUWsFlBjW2oa1Lona+jWo2tiNB5Ig7EASbYd9rOXPLzOstQ1rXRK19WsgtbV6zA5gcNoe2QEMCGEHkmgl7LZvtf2e7QO2H2yjhm5sH7T9ru2dba9PV6yhd8z27kltl9rebHt/cTvlGnst1bbW9uFi2+20vayl2ubafsv2Xtt7bN9ftLe67UrqGsh2G/gxu+0Zkt6XdLOkQ5K2S1oZEXsHWkgXtg9KWhQRrX8Bw/ZSSZ9Leu700lq2/yDpeEQ8VvxHOSsifjMkta3VWS7j3VBt3ZYZ/6Va3HZ1Ln/ejzZG9sWSDkTEBxHxpaQXJS1voY6hFxFbJR0/o3m5pPXF/fXq/GMZuC61DYWIOBIRO4r7JySdXma81W1XUtdAtBH2yyR9POnxIQ3Xeu8h6U3b79gebbuYKYxMWmbrE0kjbRYzhZ7LeA/SGcuMD82262f586o4QfdtSyJioaRfSLq32F0dStE5BhumudMnJV2lzhqARyQ93mYxxTLjL0t6ICI+m9zX5raboq6BbLc2wn5Y0txJjy8v2oZCRBwubo9J2qDOYccwOXp6Bd3i9l
jL9XwtIo5GxEREnJK0Ti1uu2KZ8ZclPR8RrxTNrW+7qeoa1HZrI+zbJV1t+0rbF0q6U9KmFur4FtszixMnsj1T0i0
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"random_img = test_dataset[31][0].numpy() * stddev_gray + mean_gray\n",
"plt.imshow(random_img.reshape(28,28), cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def predict(img_name,model,show_image=False):\n",
" image = cv2.imread(img_name,0)\n",
" ret, thresholded = cv2.threshold(image,127,255,cv2.THRESH_BINARY)\n",
" img = 255 - thresholded\n",
" if show_image == True:\n",
" cv2.imshow('Original',img)\n",
" cv2.waitKey(0)\n",
" cv2.destroyAllWindows()\n",
" \n",
" img = Image.fromarray(img)\n",
" img = transforms_photo(img)\n",
" img = img.view(1,1,28,28)\n",
" #img = Variable(img)\n",
" \n",
" model.eval()\n",
" \n",
" if CUDA:\n",
" model = model.cuda()\n",
" img = img.cuda()\n",
" \n",
" output = model(img)\n",
" print(output)\n",
" print(output.data)\n",
" _, predicted = torch.max(output,1)\n",
" return predicted.item()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-5.4035, 2.6434, 8.6153, -3.0462, -7.2718, -7.2775, -1.7709, 0.8279,\n",
" -4.8467, -8.3950]], device='cuda:0', grad_fn=<AddmmBackward>)\n",
"tensor([[-5.4035, 2.6434, 8.6153, -3.0462, -7.2718, -7.2775, -1.7709, 0.8279,\n",
" -4.8467, -8.3950]], device='cuda:0')\n",
"the predicted label is 2\n"
]
}
],
"source": [
"pred = predict('photo.jpg',model)\n",
"print('the predicted label is {}'.format(pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}