pytorch-stuff/Diabetes_NN.ipynb

761 lines
24 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/eddie/.pyenv/versions/3.7.6/envs/pytorch/lib/python3.7/site-packages/pandas/compat/__init__.py:117: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import pandas as pd\n",
"from sklearn.preprocessing import StandardScaler\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data import DataLoader\n",
"from torch.nn import BCELoss\n",
"from torch.optim import SGD"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Load the diabetes data set with pandas (path is relative to the notebook's working directory).\n",
"# Per the preview below, the last column ('Class') holds the text label 'positive'/'negative'\n",
"# and the preceding 7 columns are numeric features.\n",
"data = pd.read_csv('diabetes.csv')\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Number of times pregnant</th>\n",
" <th>Plasma glucose concentration</th>\n",
" <th>Diastolic blood pressure</th>\n",
" <th>Triceps skin fold thickness</th>\n",
" <th>2-Hour serum insulin</th>\n",
" <th>Body mass index</th>\n",
" <th>Age</th>\n",
" <th>Class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>6</td>\n",
" <td>148</td>\n",
" <td>72</td>\n",
" <td>35</td>\n",
" <td>0</td>\n",
" <td>33.6</td>\n",
" <td>50</td>\n",
" <td>positive</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>85</td>\n",
" <td>66</td>\n",
" <td>29</td>\n",
" <td>0</td>\n",
" <td>26.6</td>\n",
" <td>31</td>\n",
" <td>negative</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>8</td>\n",
" <td>183</td>\n",
" <td>64</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>23.3</td>\n",
" <td>32</td>\n",
" <td>positive</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>89</td>\n",
" <td>66</td>\n",
" <td>23</td>\n",
" <td>94</td>\n",
" <td>28.1</td>\n",
" <td>21</td>\n",
" <td>negative</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>137</td>\n",
" <td>40</td>\n",
" <td>35</td>\n",
" <td>168</td>\n",
" <td>43.1</td>\n",
" <td>33</td>\n",
" <td>positive</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Number of times pregnant Plasma glucose concentration \\\n",
"0 6 148 \n",
"1 1 85 \n",
"2 8 183 \n",
"3 1 89 \n",
"4 0 137 \n",
"\n",
" Diastolic blood pressure Triceps skin fold thickness \\\n",
"0 72 35 \n",
"1 66 29 \n",
"2 64 0 \n",
"3 66 23 \n",
"4 40 35 \n",
"\n",
" 2-Hour serum insulin Body mass index Age Class \n",
"0 0 33.6 50 positive \n",
"1 0 26.6 31 negative \n",
"2 0 23.3 32 positive \n",
"3 94 28.1 21 negative \n",
"4 168 43.1 33 positive "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Preview the first rows: 7 numeric feature columns plus the text 'Class' label\n",
"data.head() "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Feature matrix: every column except the trailing 'Class' label column\n",
"x = data.iloc[:, :-1].to_numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(768, 7)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: (n_samples, n_features)\n",
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Raw text labels from the last column, as a plain Python list\n",
"y_string = data.iloc[:, -1].tolist()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"768"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One label per row; should equal x.shape[0]\n",
"len(y_string)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Encode the text labels as integers: 'positive' -> 1, 'negative' -> 0.\n",
"# Fail loudly on any other value instead of silently mapping it to 0.\n",
"unexpected_labels = set(y_string) - {'positive', 'negative'}\n",
"assert not unexpected_labels, 'unexpected class labels: {}'.format(unexpected_labels)\n",
"y_int = [1 if label == 'positive' else 0 for label in y_string]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Labels as a float64 vector; the training loop casts them to float32 per batch\n",
"y = np.asarray(y_int, dtype=np.float64)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Data normalization: standardize each feature column to zero mean and unit variance.\n",
"# NOTE(review): the scaler is fit on all rows; if a train/test split is added later,\n",
"# fit it on the training rows only so test statistics do not leak into training.\n",
"sc = StandardScaler()\n",
"x = sc.fit_transform(x)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Convert to tensors once, as float32 to match nn.Linear's default parameter dtype.\n",
"# torch.as_tensor keeps this cell safe to re-run: it returns the input unchanged when\n",
"# it is already a float32 tensor (torch.tensor would copy it and emit a warning).\n",
"x = torch.as_tensor(x, dtype=torch.float32)\n",
"y = torch.as_tensor(y, dtype=torch.float32)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([768, 7])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Shape is unchanged by standardization and tensor conversion\n",
"x.shape"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Make the labels a (n_samples, 1) column so they match the model's (batch, 1) output.\n",
"# reshape(-1, 1) is idempotent; unsqueeze(1) would add another axis on every re-run.\n",
"y = y.reshape(-1, 1)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([768, 1])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Labels are now a column vector, matching the (batch, 1) model output\n",
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class DiabetesDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Map-style dataset pairing each feature row with its label.\n",
"\n",
"    Args:\n",
"        x: indexable feature container (here the (768, 7) feature tensor).\n",
"        y: indexable label container with the same length as ``x``.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``x`` and ``y`` have different lengths.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, x, y):\n",
"        if len(x) != len(y):\n",
"            raise ValueError('x and y must have the same length, got {} and {}'.format(len(x), len(y)))\n",
"        self.x = x\n",
"        self.y = y\n",
"\n",
"    def __getitem__(self, index):\n",
"        return self.x[index], self.y[index]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.x)\n",
"\n",
"\n",
"# The base class is referenced through torch.utils.data, so re-running this cell no longer\n",
"# subclasses the previous definition; keep the old name as an alias for the cells below.\n",
"Dataset = DiabetesDataset"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the feature/label tensors so the DataLoader can index them\n",
"dataset = Dataset(x=x, y=y)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"768"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of samples the DataLoader will draw from\n",
"len(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Mini-batches of BATCH_SIZE rows, reshuffled at the start of every epoch\n",
"BATCH_SIZE = 32\n",
"train_loader = DataLoader(dataset=dataset,\n",
"                          batch_size=BATCH_SIZE,\n",
"                          shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.data.dataloader.DataLoader at 0x124dd84d0>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Only shows the object repr; the next cell inspects actual batch shapes\n",
"train_loader"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"There are 24 batches in the dataset\n",
"For one iteration (batch) there are:\n",
"Data: torch.Size([32, 7])\n",
"labels: torch.Size([32, 1])\n"
]
}
],
"source": [
"# Inspect one mini-batch from the train loader.\n",
"# Use dedicated names: a `for (x, y) in train_loader` loop here would overwrite the\n",
"# full-data x/y tensors with a single batch.\n",
"print(\"There are {} batches in the dataset\".format(len(train_loader)))\n",
"batch_x, batch_y = next(iter(train_loader))\n",
"print(\"For one iteration (batch) there are:\")\n",
"print(\"Data: {}\".format(batch_x.shape))\n",
"print(\"labels: {}\".format(batch_y.shape))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"24.0"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expected number of batches: dataset size / batch size (equals len(train_loader) when it divides evenly)\n",
"len(dataset) / train_loader.batch_size"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Fully connected binary classifier: three tanh hidden layers and a sigmoid output.\n",
"\n",
"    Args:\n",
"        input_features: number of input features per sample.\n",
"        output_features: number of output units (1 for a single probability).\n",
"        hidden_1, hidden_2, hidden_3: widths of the hidden layers (defaults 5, 4, 3).\n",
"\n",
"    forward() returns sigmoid outputs in (0, 1), the input BCELoss expects.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, input_features, output_features, hidden_1=5, hidden_2=4, hidden_3=3):\n",
"        super(Model, self).__init__()\n",
"        self.fc1 = nn.Linear(input_features, hidden_1)\n",
"        self.fc2 = nn.Linear(hidden_1, hidden_2)\n",
"        self.fc3 = nn.Linear(hidden_2, hidden_3)\n",
"        self.fc4 = nn.Linear(hidden_3, output_features)\n",
"        self.sigmoid = nn.Sigmoid()\n",
"        self.tanh = nn.Tanh()\n",
"\n",
"    def forward(self, x):\n",
"        out = self.tanh(self.fc1(x))\n",
"        out = self.tanh(self.fc2(out))\n",
"        out = self.tanh(self.fc3(out))\n",
"        # NOTE(review): nn.BCEWithLogitsLoss on the raw fc4 output is numerically more stable\n",
"        # than Sigmoid + BCELoss, but switching would change what forward() returns.\n",
"        return self.sigmoid(self.fc4(out))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# Seed torch's global RNG so weight initialisation and DataLoader shuffling are reproducible\n",
"SEED = 0\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Network: one input unit per feature column -> 1 output probability\n",
"net = Model(x.shape[1], 1)\n",
"\n",
"# Binary cross entropy is the loss function (averaged over the batch)\n",
"criterion = BCELoss(reduction='mean')\n",
"\n",
"# Stochastic gradient descent with momentum\n",
"optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9)\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch: 1/200, Loss: 0.551, Accuracy: 0.719\n",
"Epoch: 2/200, Loss: 0.473, Accuracy: 0.781\n",
"Epoch: 3/200, Loss: 0.536, Accuracy: 0.750\n",
"Epoch: 4/200, Loss: 0.519, Accuracy: 0.656\n",
"Epoch: 5/200, Loss: 0.453, Accuracy: 0.781\n",
"Epoch: 6/200, Loss: 0.666, Accuracy: 0.594\n",
"Epoch: 7/200, Loss: 0.455, Accuracy: 0.719\n",
"Epoch: 8/200, Loss: 0.617, Accuracy: 0.656\n",
"Epoch: 9/200, Loss: 0.367, Accuracy: 0.875\n",
"Epoch: 10/200, Loss: 0.496, Accuracy: 0.719\n",
"Epoch: 11/200, Loss: 0.483, Accuracy: 0.812\n",
"Epoch: 12/200, Loss: 0.556, Accuracy: 0.656\n",
"Epoch: 13/200, Loss: 0.455, Accuracy: 0.719\n",
"Epoch: 14/200, Loss: 0.504, Accuracy: 0.781\n",
"Epoch: 15/200, Loss: 0.492, Accuracy: 0.750\n",
"Epoch: 16/200, Loss: 0.603, Accuracy: 0.781\n",
"Epoch: 17/200, Loss: 0.417, Accuracy: 0.906\n",
"Epoch: 18/200, Loss: 0.450, Accuracy: 0.844\n",
"Epoch: 19/200, Loss: 0.313, Accuracy: 0.844\n",
"Epoch: 20/200, Loss: 0.642, Accuracy: 0.656\n",
"Epoch: 21/200, Loss: 0.369, Accuracy: 0.844\n",
"Epoch: 22/200, Loss: 0.499, Accuracy: 0.781\n",
"Epoch: 23/200, Loss: 0.477, Accuracy: 0.781\n",
"Epoch: 24/200, Loss: 0.440, Accuracy: 0.750\n",
"Epoch: 25/200, Loss: 0.523, Accuracy: 0.781\n",
"Epoch: 26/200, Loss: 0.362, Accuracy: 0.906\n",
"Epoch: 27/200, Loss: 0.424, Accuracy: 0.875\n",
"Epoch: 28/200, Loss: 0.389, Accuracy: 0.781\n",
"Epoch: 29/200, Loss: 0.534, Accuracy: 0.688\n",
"Epoch: 30/200, Loss: 0.435, Accuracy: 0.844\n",
"Epoch: 31/200, Loss: 0.315, Accuracy: 0.875\n",
"Epoch: 32/200, Loss: 0.574, Accuracy: 0.688\n",
"Epoch: 33/200, Loss: 0.485, Accuracy: 0.750\n",
"Epoch: 34/200, Loss: 0.373, Accuracy: 0.781\n",
"Epoch: 35/200, Loss: 0.406, Accuracy: 0.844\n",
"Epoch: 36/200, Loss: 0.497, Accuracy: 0.781\n",
"Epoch: 37/200, Loss: 0.276, Accuracy: 0.969\n",
"Epoch: 38/200, Loss: 0.533, Accuracy: 0.719\n",
"Epoch: 39/200, Loss: 0.465, Accuracy: 0.781\n",
"Epoch: 40/200, Loss: 0.435, Accuracy: 0.750\n",
"Epoch: 41/200, Loss: 0.435, Accuracy: 0.812\n",
"Epoch: 42/200, Loss: 0.347, Accuracy: 0.812\n",
"Epoch: 43/200, Loss: 0.574, Accuracy: 0.781\n",
"Epoch: 44/200, Loss: 0.383, Accuracy: 0.844\n",
"Epoch: 45/200, Loss: 0.550, Accuracy: 0.750\n",
"Epoch: 46/200, Loss: 0.424, Accuracy: 0.844\n",
"Epoch: 47/200, Loss: 0.647, Accuracy: 0.625\n",
"Epoch: 48/200, Loss: 0.469, Accuracy: 0.719\n",
"Epoch: 49/200, Loss: 0.360, Accuracy: 0.781\n",
"Epoch: 50/200, Loss: 0.303, Accuracy: 0.938\n",
"Epoch: 51/200, Loss: 0.523, Accuracy: 0.781\n",
"Epoch: 52/200, Loss: 0.532, Accuracy: 0.719\n",
"Epoch: 53/200, Loss: 0.491, Accuracy: 0.812\n",
"Epoch: 54/200, Loss: 0.337, Accuracy: 0.812\n",
"Epoch: 55/200, Loss: 0.404, Accuracy: 0.781\n",
"Epoch: 56/200, Loss: 0.383, Accuracy: 0.875\n",
"Epoch: 57/200, Loss: 0.574, Accuracy: 0.688\n",
"Epoch: 58/200, Loss: 0.449, Accuracy: 0.781\n",
"Epoch: 59/200, Loss: 0.369, Accuracy: 0.875\n",
"Epoch: 60/200, Loss: 0.615, Accuracy: 0.812\n",
"Epoch: 61/200, Loss: 0.506, Accuracy: 0.781\n",
"Epoch: 62/200, Loss: 0.595, Accuracy: 0.750\n",
"Epoch: 63/200, Loss: 0.421, Accuracy: 0.781\n",
"Epoch: 64/200, Loss: 0.387, Accuracy: 0.812\n",
"Epoch: 65/200, Loss: 0.358, Accuracy: 0.844\n",
"Epoch: 66/200, Loss: 0.413, Accuracy: 0.812\n",
"Epoch: 67/200, Loss: 0.445, Accuracy: 0.719\n",
"Epoch: 68/200, Loss: 0.397, Accuracy: 0.750\n",
"Epoch: 69/200, Loss: 0.485, Accuracy: 0.781\n",
"Epoch: 70/200, Loss: 0.449, Accuracy: 0.750\n",
"Epoch: 71/200, Loss: 0.385, Accuracy: 0.812\n",
"Epoch: 72/200, Loss: 0.587, Accuracy: 0.688\n",
"Epoch: 73/200, Loss: 0.414, Accuracy: 0.906\n",
"Epoch: 74/200, Loss: 0.466, Accuracy: 0.844\n",
"Epoch: 75/200, Loss: 0.441, Accuracy: 0.875\n",
"Epoch: 76/200, Loss: 0.386, Accuracy: 0.781\n",
"Epoch: 77/200, Loss: 0.482, Accuracy: 0.781\n",
"Epoch: 78/200, Loss: 0.456, Accuracy: 0.812\n",
"Epoch: 79/200, Loss: 0.216, Accuracy: 0.969\n",
"Epoch: 80/200, Loss: 0.457, Accuracy: 0.719\n",
"Epoch: 81/200, Loss: 0.481, Accuracy: 0.750\n",
"Epoch: 82/200, Loss: 0.587, Accuracy: 0.719\n",
"Epoch: 83/200, Loss: 0.338, Accuracy: 0.844\n",
"Epoch: 84/200, Loss: 0.560, Accuracy: 0.688\n",
"Epoch: 85/200, Loss: 0.343, Accuracy: 0.875\n",
"Epoch: 86/200, Loss: 0.383, Accuracy: 0.844\n",
"Epoch: 87/200, Loss: 0.355, Accuracy: 0.812\n",
"Epoch: 88/200, Loss: 0.353, Accuracy: 0.875\n",
"Epoch: 89/200, Loss: 0.334, Accuracy: 0.812\n",
"Epoch: 90/200, Loss: 0.264, Accuracy: 0.875\n",
"Epoch: 91/200, Loss: 0.473, Accuracy: 0.781\n",
"Epoch: 92/200, Loss: 0.570, Accuracy: 0.750\n",
"Epoch: 93/200, Loss: 0.317, Accuracy: 0.844\n",
"Epoch: 94/200, Loss: 0.356, Accuracy: 0.781\n",
"Epoch: 95/200, Loss: 0.541, Accuracy: 0.812\n",
"Epoch: 96/200, Loss: 0.352, Accuracy: 0.812\n",
"Epoch: 97/200, Loss: 0.427, Accuracy: 0.750\n",
"Epoch: 98/200, Loss: 0.367, Accuracy: 0.812\n",
"Epoch: 99/200, Loss: 0.456, Accuracy: 0.719\n",
"Epoch: 100/200, Loss: 0.326, Accuracy: 0.781\n",
"Epoch: 101/200, Loss: 0.477, Accuracy: 0.719\n",
"Epoch: 102/200, Loss: 0.310, Accuracy: 0.844\n",
"Epoch: 103/200, Loss: 0.236, Accuracy: 0.906\n",
"Epoch: 104/200, Loss: 0.553, Accuracy: 0.625\n",
"Epoch: 105/200, Loss: 0.323, Accuracy: 0.906\n",
"Epoch: 106/200, Loss: 0.458, Accuracy: 0.688\n",
"Epoch: 107/200, Loss: 0.505, Accuracy: 0.719\n",
"Epoch: 108/200, Loss: 0.493, Accuracy: 0.781\n",
"Epoch: 109/200, Loss: 0.492, Accuracy: 0.688\n",
"Epoch: 110/200, Loss: 0.400, Accuracy: 0.719\n",
"Epoch: 111/200, Loss: 0.423, Accuracy: 0.844\n",
"Epoch: 112/200, Loss: 0.549, Accuracy: 0.750\n",
"Epoch: 113/200, Loss: 0.564, Accuracy: 0.750\n",
"Epoch: 114/200, Loss: 0.349, Accuracy: 0.844\n",
"Epoch: 115/200, Loss: 0.438, Accuracy: 0.812\n",
"Epoch: 116/200, Loss: 0.327, Accuracy: 0.812\n",
"Epoch: 117/200, Loss: 0.697, Accuracy: 0.688\n",
"Epoch: 118/200, Loss: 0.555, Accuracy: 0.781\n",
"Epoch: 119/200, Loss: 0.500, Accuracy: 0.750\n",
"Epoch: 120/200, Loss: 0.376, Accuracy: 0.812\n",
"Epoch: 121/200, Loss: 0.522, Accuracy: 0.750\n",
"Epoch: 122/200, Loss: 0.292, Accuracy: 0.906\n",
"Epoch: 123/200, Loss: 0.475, Accuracy: 0.750\n",
"Epoch: 124/200, Loss: 0.394, Accuracy: 0.812\n",
"Epoch: 125/200, Loss: 0.293, Accuracy: 0.812\n",
"Epoch: 126/200, Loss: 0.312, Accuracy: 0.875\n",
"Epoch: 127/200, Loss: 0.516, Accuracy: 0.781\n",
"Epoch: 128/200, Loss: 0.433, Accuracy: 0.750\n",
"Epoch: 129/200, Loss: 0.376, Accuracy: 0.812\n",
"Epoch: 130/200, Loss: 0.279, Accuracy: 0.844\n",
"Epoch: 131/200, Loss: 0.486, Accuracy: 0.750\n",
"Epoch: 132/200, Loss: 0.542, Accuracy: 0.562\n",
"Epoch: 133/200, Loss: 0.365, Accuracy: 0.812\n",
"Epoch: 134/200, Loss: 0.533, Accuracy: 0.844\n",
"Epoch: 135/200, Loss: 0.341, Accuracy: 0.844\n",
"Epoch: 136/200, Loss: 0.282, Accuracy: 0.906\n",
"Epoch: 137/200, Loss: 0.490, Accuracy: 0.719\n",
"Epoch: 138/200, Loss: 0.481, Accuracy: 0.719\n",
"Epoch: 139/200, Loss: 0.435, Accuracy: 0.812\n",
"Epoch: 140/200, Loss: 0.622, Accuracy: 0.750\n",
"Epoch: 141/200, Loss: 0.346, Accuracy: 0.781\n",
"Epoch: 142/200, Loss: 0.258, Accuracy: 0.938\n",
"Epoch: 143/200, Loss: 0.407, Accuracy: 0.781\n",
"Epoch: 144/200, Loss: 0.461, Accuracy: 0.750\n",
"Epoch: 145/200, Loss: 0.350, Accuracy: 0.812\n",
"Epoch: 146/200, Loss: 0.399, Accuracy: 0.844\n",
"Epoch: 147/200, Loss: 0.358, Accuracy: 0.781\n",
"Epoch: 148/200, Loss: 0.357, Accuracy: 0.781\n",
"Epoch: 149/200, Loss: 0.426, Accuracy: 0.781\n",
"Epoch: 150/200, Loss: 0.349, Accuracy: 0.844\n",
"Epoch: 151/200, Loss: 0.619, Accuracy: 0.719\n",
"Epoch: 152/200, Loss: 0.418, Accuracy: 0.781\n",
"Epoch: 153/200, Loss: 0.604, Accuracy: 0.625\n",
"Epoch: 154/200, Loss: 0.397, Accuracy: 0.781\n",
"Epoch: 155/200, Loss: 0.366, Accuracy: 0.875\n",
"Epoch: 156/200, Loss: 0.309, Accuracy: 0.875\n",
"Epoch: 157/200, Loss: 0.337, Accuracy: 0.812\n",
"Epoch: 158/200, Loss: 0.404, Accuracy: 0.875\n",
"Epoch: 159/200, Loss: 0.362, Accuracy: 0.844\n",
"Epoch: 160/200, Loss: 0.400, Accuracy: 0.812\n",
"Epoch: 161/200, Loss: 0.422, Accuracy: 0.812\n",
"Epoch: 162/200, Loss: 0.290, Accuracy: 0.844\n",
"Epoch: 163/200, Loss: 0.360, Accuracy: 0.781\n",
"Epoch: 164/200, Loss: 0.552, Accuracy: 0.719\n",
"Epoch: 165/200, Loss: 0.438, Accuracy: 0.781\n",
"Epoch: 166/200, Loss: 0.429, Accuracy: 0.781\n",
"Epoch: 167/200, Loss: 0.322, Accuracy: 0.938\n",
"Epoch: 168/200, Loss: 0.383, Accuracy: 0.812\n",
"Epoch: 169/200, Loss: 0.484, Accuracy: 0.656\n",
"Epoch: 170/200, Loss: 0.398, Accuracy: 0.875\n",
"Epoch: 171/200, Loss: 0.296, Accuracy: 0.906\n",
"Epoch: 172/200, Loss: 0.491, Accuracy: 0.719\n",
"Epoch: 173/200, Loss: 0.729, Accuracy: 0.688\n",
"Epoch: 174/200, Loss: 0.481, Accuracy: 0.812\n",
"Epoch: 175/200, Loss: 0.432, Accuracy: 0.688\n",
"Epoch: 176/200, Loss: 0.344, Accuracy: 0.781\n",
"Epoch: 177/200, Loss: 0.284, Accuracy: 0.906\n",
"Epoch: 178/200, Loss: 0.452, Accuracy: 0.750\n",
"Epoch: 179/200, Loss: 0.419, Accuracy: 0.781\n",
"Epoch: 180/200, Loss: 0.505, Accuracy: 0.750\n",
"Epoch: 181/200, Loss: 0.372, Accuracy: 0.844\n",
"Epoch: 182/200, Loss: 0.429, Accuracy: 0.750\n",
"Epoch: 183/200, Loss: 0.579, Accuracy: 0.688\n",
"Epoch: 184/200, Loss: 0.386, Accuracy: 0.812\n",
"Epoch: 185/200, Loss: 0.359, Accuracy: 0.812\n",
"Epoch: 186/200, Loss: 0.402, Accuracy: 0.781\n",
"Epoch: 187/200, Loss: 0.542, Accuracy: 0.688\n",
"Epoch: 188/200, Loss: 0.388, Accuracy: 0.875\n",
"Epoch: 189/200, Loss: 0.410, Accuracy: 0.875\n",
"Epoch: 190/200, Loss: 0.344, Accuracy: 0.812\n",
"Epoch: 191/200, Loss: 0.484, Accuracy: 0.844\n",
"Epoch: 192/200, Loss: 0.480, Accuracy: 0.719\n",
"Epoch: 193/200, Loss: 0.560, Accuracy: 0.750\n",
"Epoch: 194/200, Loss: 0.466, Accuracy: 0.750\n",
"Epoch: 195/200, Loss: 0.238, Accuracy: 0.938\n",
"Epoch: 196/200, Loss: 0.423, Accuracy: 0.844\n",
"Epoch: 197/200, Loss: 0.366, Accuracy: 0.812\n",
"Epoch: 198/200, Loss: 0.356, Accuracy: 0.875\n",
"Epoch: 199/200, Loss: 0.373, Accuracy: 0.812\n",
"Epoch: 200/200, Loss: 0.473, Accuracy: 0.781\n"
]
}
],
"source": [
"# Training the network\n",
"epochs = 200\n",
"for epoch in range(epochs):\n",
"    net.train()\n",
"    epoch_loss = 0.0\n",
"    epoch_correct = 0\n",
"    epoch_samples = 0\n",
"    for inputs, labels in train_loader:\n",
"        inputs = inputs.float()\n",
"        labels = labels.float()\n",
"        # Forward propagation\n",
"        outputs = net(inputs)\n",
"        # Loss calculation (mean over the batch)\n",
"        loss = criterion(outputs, labels)\n",
"        # Clear the gradients accumulated by the previous backward() call\n",
"        optimizer.zero_grad()\n",
"        # Back propagation\n",
"        loss.backward()\n",
"        # Update weights using the gradients just computed\n",
"        optimizer.step()\n",
"\n",
"        # Accumulate sample-weighted statistics so the printed numbers cover the whole\n",
"        # epoch, not just its last batch (measured with the weights before each update).\n",
"        batch_size = labels.size(0)\n",
"        epoch_loss += loss.item() * batch_size\n",
"        predictions = (outputs.detach() > 0.5).float()\n",
"        epoch_correct += (predictions == labels).sum().item()\n",
"        epoch_samples += batch_size\n",
"\n",
"    # Print epoch-level statistics (training-set metrics; there is no held-out split)\n",
"    print(\"Epoch: {}/{}, Loss: {:.3f}, Accuracy: {:.3f}\".format(\n",
"        epoch + 1, epochs, epoch_loss / epoch_samples, epoch_correct / epoch_samples))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}