[MOD] Finished diabetes classification NN added visualization of NN training
This commit is contained in:
parent
21b5ce8dd4
commit
1e8d40791c
|
@ -2,9 +2,18 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/eddie/.pyenv/versions/3.7.6/envs/pytorch/lib/python3.7/site-packages/pandas/compat/__init__.py:117: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.\n",
|
||||
" warnings.warn(msg)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
|
@ -374,7 +383,7 @@
|
|||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch.utils.data.dataloader.DataLoader at 0x12ced9810>"
|
||||
"<torch.utils.data.dataloader.DataLoader at 0x124dd84d0>"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
|
@ -408,7 +417,7 @@
|
|||
"for (x,y) in train_loader:\n",
|
||||
" print(\"For one iteration (baatch) there are:\")\n",
|
||||
" print(\"Data: {}\".format(x.shape))\n",
|
||||
" print(\"lables: {}\".format(y.shape))\n",
|
||||
" print(\"labels: {}\".format(y.shape))\n",
|
||||
" break\n",
|
||||
" "
|
||||
]
|
||||
|
@ -465,7 +474,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -480,14 +489,251 @@
|
|||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch: 1/200, Loss: 0.551, Accuracy: 0.719\n",
|
||||
"Epoch: 2/200, Loss: 0.473, Accuracy: 0.781\n",
|
||||
"Epoch: 3/200, Loss: 0.536, Accuracy: 0.750\n",
|
||||
"Epoch: 4/200, Loss: 0.519, Accuracy: 0.656\n",
|
||||
"Epoch: 5/200, Loss: 0.453, Accuracy: 0.781\n",
|
||||
"Epoch: 6/200, Loss: 0.666, Accuracy: 0.594\n",
|
||||
"Epoch: 7/200, Loss: 0.455, Accuracy: 0.719\n",
|
||||
"Epoch: 8/200, Loss: 0.617, Accuracy: 0.656\n",
|
||||
"Epoch: 9/200, Loss: 0.367, Accuracy: 0.875\n",
|
||||
"Epoch: 10/200, Loss: 0.496, Accuracy: 0.719\n",
|
||||
"Epoch: 11/200, Loss: 0.483, Accuracy: 0.812\n",
|
||||
"Epoch: 12/200, Loss: 0.556, Accuracy: 0.656\n",
|
||||
"Epoch: 13/200, Loss: 0.455, Accuracy: 0.719\n",
|
||||
"Epoch: 14/200, Loss: 0.504, Accuracy: 0.781\n",
|
||||
"Epoch: 15/200, Loss: 0.492, Accuracy: 0.750\n",
|
||||
"Epoch: 16/200, Loss: 0.603, Accuracy: 0.781\n",
|
||||
"Epoch: 17/200, Loss: 0.417, Accuracy: 0.906\n",
|
||||
"Epoch: 18/200, Loss: 0.450, Accuracy: 0.844\n",
|
||||
"Epoch: 19/200, Loss: 0.313, Accuracy: 0.844\n",
|
||||
"Epoch: 20/200, Loss: 0.642, Accuracy: 0.656\n",
|
||||
"Epoch: 21/200, Loss: 0.369, Accuracy: 0.844\n",
|
||||
"Epoch: 22/200, Loss: 0.499, Accuracy: 0.781\n",
|
||||
"Epoch: 23/200, Loss: 0.477, Accuracy: 0.781\n",
|
||||
"Epoch: 24/200, Loss: 0.440, Accuracy: 0.750\n",
|
||||
"Epoch: 25/200, Loss: 0.523, Accuracy: 0.781\n",
|
||||
"Epoch: 26/200, Loss: 0.362, Accuracy: 0.906\n",
|
||||
"Epoch: 27/200, Loss: 0.424, Accuracy: 0.875\n",
|
||||
"Epoch: 28/200, Loss: 0.389, Accuracy: 0.781\n",
|
||||
"Epoch: 29/200, Loss: 0.534, Accuracy: 0.688\n",
|
||||
"Epoch: 30/200, Loss: 0.435, Accuracy: 0.844\n",
|
||||
"Epoch: 31/200, Loss: 0.315, Accuracy: 0.875\n",
|
||||
"Epoch: 32/200, Loss: 0.574, Accuracy: 0.688\n",
|
||||
"Epoch: 33/200, Loss: 0.485, Accuracy: 0.750\n",
|
||||
"Epoch: 34/200, Loss: 0.373, Accuracy: 0.781\n",
|
||||
"Epoch: 35/200, Loss: 0.406, Accuracy: 0.844\n",
|
||||
"Epoch: 36/200, Loss: 0.497, Accuracy: 0.781\n",
|
||||
"Epoch: 37/200, Loss: 0.276, Accuracy: 0.969\n",
|
||||
"Epoch: 38/200, Loss: 0.533, Accuracy: 0.719\n",
|
||||
"Epoch: 39/200, Loss: 0.465, Accuracy: 0.781\n",
|
||||
"Epoch: 40/200, Loss: 0.435, Accuracy: 0.750\n",
|
||||
"Epoch: 41/200, Loss: 0.435, Accuracy: 0.812\n",
|
||||
"Epoch: 42/200, Loss: 0.347, Accuracy: 0.812\n",
|
||||
"Epoch: 43/200, Loss: 0.574, Accuracy: 0.781\n",
|
||||
"Epoch: 44/200, Loss: 0.383, Accuracy: 0.844\n",
|
||||
"Epoch: 45/200, Loss: 0.550, Accuracy: 0.750\n",
|
||||
"Epoch: 46/200, Loss: 0.424, Accuracy: 0.844\n",
|
||||
"Epoch: 47/200, Loss: 0.647, Accuracy: 0.625\n",
|
||||
"Epoch: 48/200, Loss: 0.469, Accuracy: 0.719\n",
|
||||
"Epoch: 49/200, Loss: 0.360, Accuracy: 0.781\n",
|
||||
"Epoch: 50/200, Loss: 0.303, Accuracy: 0.938\n",
|
||||
"Epoch: 51/200, Loss: 0.523, Accuracy: 0.781\n",
|
||||
"Epoch: 52/200, Loss: 0.532, Accuracy: 0.719\n",
|
||||
"Epoch: 53/200, Loss: 0.491, Accuracy: 0.812\n",
|
||||
"Epoch: 54/200, Loss: 0.337, Accuracy: 0.812\n",
|
||||
"Epoch: 55/200, Loss: 0.404, Accuracy: 0.781\n",
|
||||
"Epoch: 56/200, Loss: 0.383, Accuracy: 0.875\n",
|
||||
"Epoch: 57/200, Loss: 0.574, Accuracy: 0.688\n",
|
||||
"Epoch: 58/200, Loss: 0.449, Accuracy: 0.781\n",
|
||||
"Epoch: 59/200, Loss: 0.369, Accuracy: 0.875\n",
|
||||
"Epoch: 60/200, Loss: 0.615, Accuracy: 0.812\n",
|
||||
"Epoch: 61/200, Loss: 0.506, Accuracy: 0.781\n",
|
||||
"Epoch: 62/200, Loss: 0.595, Accuracy: 0.750\n",
|
||||
"Epoch: 63/200, Loss: 0.421, Accuracy: 0.781\n",
|
||||
"Epoch: 64/200, Loss: 0.387, Accuracy: 0.812\n",
|
||||
"Epoch: 65/200, Loss: 0.358, Accuracy: 0.844\n",
|
||||
"Epoch: 66/200, Loss: 0.413, Accuracy: 0.812\n",
|
||||
"Epoch: 67/200, Loss: 0.445, Accuracy: 0.719\n",
|
||||
"Epoch: 68/200, Loss: 0.397, Accuracy: 0.750\n",
|
||||
"Epoch: 69/200, Loss: 0.485, Accuracy: 0.781\n",
|
||||
"Epoch: 70/200, Loss: 0.449, Accuracy: 0.750\n",
|
||||
"Epoch: 71/200, Loss: 0.385, Accuracy: 0.812\n",
|
||||
"Epoch: 72/200, Loss: 0.587, Accuracy: 0.688\n",
|
||||
"Epoch: 73/200, Loss: 0.414, Accuracy: 0.906\n",
|
||||
"Epoch: 74/200, Loss: 0.466, Accuracy: 0.844\n",
|
||||
"Epoch: 75/200, Loss: 0.441, Accuracy: 0.875\n",
|
||||
"Epoch: 76/200, Loss: 0.386, Accuracy: 0.781\n",
|
||||
"Epoch: 77/200, Loss: 0.482, Accuracy: 0.781\n",
|
||||
"Epoch: 78/200, Loss: 0.456, Accuracy: 0.812\n",
|
||||
"Epoch: 79/200, Loss: 0.216, Accuracy: 0.969\n",
|
||||
"Epoch: 80/200, Loss: 0.457, Accuracy: 0.719\n",
|
||||
"Epoch: 81/200, Loss: 0.481, Accuracy: 0.750\n",
|
||||
"Epoch: 82/200, Loss: 0.587, Accuracy: 0.719\n",
|
||||
"Epoch: 83/200, Loss: 0.338, Accuracy: 0.844\n",
|
||||
"Epoch: 84/200, Loss: 0.560, Accuracy: 0.688\n",
|
||||
"Epoch: 85/200, Loss: 0.343, Accuracy: 0.875\n",
|
||||
"Epoch: 86/200, Loss: 0.383, Accuracy: 0.844\n",
|
||||
"Epoch: 87/200, Loss: 0.355, Accuracy: 0.812\n",
|
||||
"Epoch: 88/200, Loss: 0.353, Accuracy: 0.875\n",
|
||||
"Epoch: 89/200, Loss: 0.334, Accuracy: 0.812\n",
|
||||
"Epoch: 90/200, Loss: 0.264, Accuracy: 0.875\n",
|
||||
"Epoch: 91/200, Loss: 0.473, Accuracy: 0.781\n",
|
||||
"Epoch: 92/200, Loss: 0.570, Accuracy: 0.750\n",
|
||||
"Epoch: 93/200, Loss: 0.317, Accuracy: 0.844\n",
|
||||
"Epoch: 94/200, Loss: 0.356, Accuracy: 0.781\n",
|
||||
"Epoch: 95/200, Loss: 0.541, Accuracy: 0.812\n",
|
||||
"Epoch: 96/200, Loss: 0.352, Accuracy: 0.812\n",
|
||||
"Epoch: 97/200, Loss: 0.427, Accuracy: 0.750\n",
|
||||
"Epoch: 98/200, Loss: 0.367, Accuracy: 0.812\n",
|
||||
"Epoch: 99/200, Loss: 0.456, Accuracy: 0.719\n",
|
||||
"Epoch: 100/200, Loss: 0.326, Accuracy: 0.781\n",
|
||||
"Epoch: 101/200, Loss: 0.477, Accuracy: 0.719\n",
|
||||
"Epoch: 102/200, Loss: 0.310, Accuracy: 0.844\n",
|
||||
"Epoch: 103/200, Loss: 0.236, Accuracy: 0.906\n",
|
||||
"Epoch: 104/200, Loss: 0.553, Accuracy: 0.625\n",
|
||||
"Epoch: 105/200, Loss: 0.323, Accuracy: 0.906\n",
|
||||
"Epoch: 106/200, Loss: 0.458, Accuracy: 0.688\n",
|
||||
"Epoch: 107/200, Loss: 0.505, Accuracy: 0.719\n",
|
||||
"Epoch: 108/200, Loss: 0.493, Accuracy: 0.781\n",
|
||||
"Epoch: 109/200, Loss: 0.492, Accuracy: 0.688\n",
|
||||
"Epoch: 110/200, Loss: 0.400, Accuracy: 0.719\n",
|
||||
"Epoch: 111/200, Loss: 0.423, Accuracy: 0.844\n",
|
||||
"Epoch: 112/200, Loss: 0.549, Accuracy: 0.750\n",
|
||||
"Epoch: 113/200, Loss: 0.564, Accuracy: 0.750\n",
|
||||
"Epoch: 114/200, Loss: 0.349, Accuracy: 0.844\n",
|
||||
"Epoch: 115/200, Loss: 0.438, Accuracy: 0.812\n",
|
||||
"Epoch: 116/200, Loss: 0.327, Accuracy: 0.812\n",
|
||||
"Epoch: 117/200, Loss: 0.697, Accuracy: 0.688\n",
|
||||
"Epoch: 118/200, Loss: 0.555, Accuracy: 0.781\n",
|
||||
"Epoch: 119/200, Loss: 0.500, Accuracy: 0.750\n",
|
||||
"Epoch: 120/200, Loss: 0.376, Accuracy: 0.812\n",
|
||||
"Epoch: 121/200, Loss: 0.522, Accuracy: 0.750\n",
|
||||
"Epoch: 122/200, Loss: 0.292, Accuracy: 0.906\n",
|
||||
"Epoch: 123/200, Loss: 0.475, Accuracy: 0.750\n",
|
||||
"Epoch: 124/200, Loss: 0.394, Accuracy: 0.812\n",
|
||||
"Epoch: 125/200, Loss: 0.293, Accuracy: 0.812\n",
|
||||
"Epoch: 126/200, Loss: 0.312, Accuracy: 0.875\n",
|
||||
"Epoch: 127/200, Loss: 0.516, Accuracy: 0.781\n",
|
||||
"Epoch: 128/200, Loss: 0.433, Accuracy: 0.750\n",
|
||||
"Epoch: 129/200, Loss: 0.376, Accuracy: 0.812\n",
|
||||
"Epoch: 130/200, Loss: 0.279, Accuracy: 0.844\n",
|
||||
"Epoch: 131/200, Loss: 0.486, Accuracy: 0.750\n",
|
||||
"Epoch: 132/200, Loss: 0.542, Accuracy: 0.562\n",
|
||||
"Epoch: 133/200, Loss: 0.365, Accuracy: 0.812\n",
|
||||
"Epoch: 134/200, Loss: 0.533, Accuracy: 0.844\n",
|
||||
"Epoch: 135/200, Loss: 0.341, Accuracy: 0.844\n",
|
||||
"Epoch: 136/200, Loss: 0.282, Accuracy: 0.906\n",
|
||||
"Epoch: 137/200, Loss: 0.490, Accuracy: 0.719\n",
|
||||
"Epoch: 138/200, Loss: 0.481, Accuracy: 0.719\n",
|
||||
"Epoch: 139/200, Loss: 0.435, Accuracy: 0.812\n",
|
||||
"Epoch: 140/200, Loss: 0.622, Accuracy: 0.750\n",
|
||||
"Epoch: 141/200, Loss: 0.346, Accuracy: 0.781\n",
|
||||
"Epoch: 142/200, Loss: 0.258, Accuracy: 0.938\n",
|
||||
"Epoch: 143/200, Loss: 0.407, Accuracy: 0.781\n",
|
||||
"Epoch: 144/200, Loss: 0.461, Accuracy: 0.750\n",
|
||||
"Epoch: 145/200, Loss: 0.350, Accuracy: 0.812\n",
|
||||
"Epoch: 146/200, Loss: 0.399, Accuracy: 0.844\n",
|
||||
"Epoch: 147/200, Loss: 0.358, Accuracy: 0.781\n",
|
||||
"Epoch: 148/200, Loss: 0.357, Accuracy: 0.781\n",
|
||||
"Epoch: 149/200, Loss: 0.426, Accuracy: 0.781\n",
|
||||
"Epoch: 150/200, Loss: 0.349, Accuracy: 0.844\n",
|
||||
"Epoch: 151/200, Loss: 0.619, Accuracy: 0.719\n",
|
||||
"Epoch: 152/200, Loss: 0.418, Accuracy: 0.781\n",
|
||||
"Epoch: 153/200, Loss: 0.604, Accuracy: 0.625\n",
|
||||
"Epoch: 154/200, Loss: 0.397, Accuracy: 0.781\n",
|
||||
"Epoch: 155/200, Loss: 0.366, Accuracy: 0.875\n",
|
||||
"Epoch: 156/200, Loss: 0.309, Accuracy: 0.875\n",
|
||||
"Epoch: 157/200, Loss: 0.337, Accuracy: 0.812\n",
|
||||
"Epoch: 158/200, Loss: 0.404, Accuracy: 0.875\n",
|
||||
"Epoch: 159/200, Loss: 0.362, Accuracy: 0.844\n",
|
||||
"Epoch: 160/200, Loss: 0.400, Accuracy: 0.812\n",
|
||||
"Epoch: 161/200, Loss: 0.422, Accuracy: 0.812\n",
|
||||
"Epoch: 162/200, Loss: 0.290, Accuracy: 0.844\n",
|
||||
"Epoch: 163/200, Loss: 0.360, Accuracy: 0.781\n",
|
||||
"Epoch: 164/200, Loss: 0.552, Accuracy: 0.719\n",
|
||||
"Epoch: 165/200, Loss: 0.438, Accuracy: 0.781\n",
|
||||
"Epoch: 166/200, Loss: 0.429, Accuracy: 0.781\n",
|
||||
"Epoch: 167/200, Loss: 0.322, Accuracy: 0.938\n",
|
||||
"Epoch: 168/200, Loss: 0.383, Accuracy: 0.812\n",
|
||||
"Epoch: 169/200, Loss: 0.484, Accuracy: 0.656\n",
|
||||
"Epoch: 170/200, Loss: 0.398, Accuracy: 0.875\n",
|
||||
"Epoch: 171/200, Loss: 0.296, Accuracy: 0.906\n",
|
||||
"Epoch: 172/200, Loss: 0.491, Accuracy: 0.719\n",
|
||||
"Epoch: 173/200, Loss: 0.729, Accuracy: 0.688\n",
|
||||
"Epoch: 174/200, Loss: 0.481, Accuracy: 0.812\n",
|
||||
"Epoch: 175/200, Loss: 0.432, Accuracy: 0.688\n",
|
||||
"Epoch: 176/200, Loss: 0.344, Accuracy: 0.781\n",
|
||||
"Epoch: 177/200, Loss: 0.284, Accuracy: 0.906\n",
|
||||
"Epoch: 178/200, Loss: 0.452, Accuracy: 0.750\n",
|
||||
"Epoch: 179/200, Loss: 0.419, Accuracy: 0.781\n",
|
||||
"Epoch: 180/200, Loss: 0.505, Accuracy: 0.750\n",
|
||||
"Epoch: 181/200, Loss: 0.372, Accuracy: 0.844\n",
|
||||
"Epoch: 182/200, Loss: 0.429, Accuracy: 0.750\n",
|
||||
"Epoch: 183/200, Loss: 0.579, Accuracy: 0.688\n",
|
||||
"Epoch: 184/200, Loss: 0.386, Accuracy: 0.812\n",
|
||||
"Epoch: 185/200, Loss: 0.359, Accuracy: 0.812\n",
|
||||
"Epoch: 186/200, Loss: 0.402, Accuracy: 0.781\n",
|
||||
"Epoch: 187/200, Loss: 0.542, Accuracy: 0.688\n",
|
||||
"Epoch: 188/200, Loss: 0.388, Accuracy: 0.875\n",
|
||||
"Epoch: 189/200, Loss: 0.410, Accuracy: 0.875\n",
|
||||
"Epoch: 190/200, Loss: 0.344, Accuracy: 0.812\n",
|
||||
"Epoch: 191/200, Loss: 0.484, Accuracy: 0.844\n",
|
||||
"Epoch: 192/200, Loss: 0.480, Accuracy: 0.719\n",
|
||||
"Epoch: 193/200, Loss: 0.560, Accuracy: 0.750\n",
|
||||
"Epoch: 194/200, Loss: 0.466, Accuracy: 0.750\n",
|
||||
"Epoch: 195/200, Loss: 0.238, Accuracy: 0.938\n",
|
||||
"Epoch: 196/200, Loss: 0.423, Accuracy: 0.844\n",
|
||||
"Epoch: 197/200, Loss: 0.366, Accuracy: 0.812\n",
|
||||
"Epoch: 198/200, Loss: 0.356, Accuracy: 0.875\n",
|
||||
"Epoch: 199/200, Loss: 0.373, Accuracy: 0.812\n",
|
||||
"Epoch: 200/200, Loss: 0.473, Accuracy: 0.781\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Training the network\n",
|
||||
"epochs = 200\n",
|
||||
"for epoch in range(epochs):\n",
|
||||
" for inputs, labels in train_loader:\n",
|
||||
" inputs = inputs.float()\n",
|
||||
" labels = labels.float()\n",
|
||||
" # Forward propagation\n",
|
||||
" outputs = net(inputs)\n",
|
||||
" # Loss calculation\n",
|
||||
" loss = criterion(outputs,labels)\n",
|
||||
" # Clear gradient buffer (w <- w - lr*gradient)\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" # Back propagation\n",
|
||||
" loss.backward()\n",
|
||||
" # Update weights\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" # Accuracy caalculation\n",
|
||||
" output = (outputs > 0.5).float()\n",
|
||||
" accuracy = (output == labels).float().mean()\n",
|
||||
" \n",
|
||||
" # Print statistics\n",
|
||||
" print(\"Epoch: {}/{}, Loss: {:.3f}, Accuracy: {:.3f}\".format(epoch+1, epochs, loss, accuracy))\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
" "
|
||||
]
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue