{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Diabetes classification with a small PyTorch network\n", "\n", "Predict the `Class` column (`positive` / `negative`) of `diabetes.csv` from its 7 numeric feature columns, using a fully connected network trained with binary cross-entropy (`BCELoss`) and SGD.\n", "\n", "**Note:** every row is used for training and there is no held-out split, so the reported accuracy is training accuracy." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "from sklearn.preprocessing import StandardScaler\n", "from torch.nn import BCELoss\n", "from torch.optim import SGD\n", "from torch.utils.data import DataLoader, Dataset" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Configuration: input file and random seed.\n", "# torch.manual_seed seeds the default generator used for nn.Linear weight initialisation and DataLoader shuffling,\n", "# so Restart & Run All reproduces the same training run.\n", "DATA_PATH = 'diabetes.csv'\n", "SEED = 0\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED);" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load the data set using pandas\n", "data = pd.read_csv(DATA_PATH)" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Number of times pregnantPlasma glucose concentrationDiastolic blood pressureTriceps skin fold thickness2-Hour serum insulinBody mass indexAgeClass
061487235033.650positive
11856629026.631negative
28183640023.332positive
318966239428.121negative
40137403516843.133positive
\n", "
" ], "text/plain": [ " Number of times pregnant Plasma glucose concentration \\\n", "0 6 148 \n", "1 1 85 \n", "2 8 183 \n", "3 1 89 \n", "4 0 137 \n", "\n", " Diastolic blood pressure Triceps skin fold thickness \\\n", "0 72 35 \n", "1 66 29 \n", "2 64 0 \n", "3 66 23 \n", "4 40 35 \n", "\n", " 2-Hour serum insulin Body mass index Age Class \n", "0 0 33.6 50 positive \n", "1 0 26.6 31 negative \n", "2 0 23.3 32 positive \n", "3 94 28.1 21 negative \n", "4 168 43.1 33 positive " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Features: every column except the last one (the class label)\n", "x_raw = data.iloc[:, :-1].values" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(768, 7)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_raw.shape" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Labels: the last column holds the strings 'positive' / 'negative'\n", "y_string = data.iloc[:, -1]" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "768" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(y_string)" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Vectorized encoding: True where the label is exactly 'positive', False otherwise\n", "is_positive = y_string == 'positive'" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Numeric targets: 1.0 for 'positive', 0.0 for anything else\n", "y_np = is_positive.to_numpy(dtype='float64')" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Data normalization: standardize each feature to zero mean and unit variance.\n", "# NOTE: the scaler is fit on the full dataset; if a held-out split is added, fit it on the training split only.\n", "sc = StandardScaler()\n", "x_scaled = sc.fit_transform(x_raw)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Build tensors from the prepared arrays (new names above keep every earlier cell idempotent).\n", "# float32 matches the dtype of the nn.Linear weights.\n", "x = torch.tensor(x_scaled, dtype=torch.float32)\n", "y = torch.tensor(y_np, dtype=torch.float32)" ] },
{ "cell_type": "code", "execution_count": 
12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([768, 7])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x.shape" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Give the labels shape (N, 1) so they match the network output compared by BCELoss.\n", "# reshape (rather than unsqueeze) keeps this cell idempotent: re-running it cannot add extra dimensions.\n", "y = y.reshape(-1, 1)" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([768, 1])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y.shape" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class DiabetesDataset(Dataset):\n", "    \"\"\"Map-style dataset returning (features, label) pairs indexed along the first dimension.\n", "\n", "    Named distinctly from torch's `Dataset` so re-running this cell neither shadows the import\n", "    nor makes the class subclass its own previous definition.\n", "    \"\"\"\n", "\n", "    def __init__(self, x, y):\n", "        if len(x) != len(y):\n", "            raise ValueError(\"x and y must have the same length, got {} and {}\".format(len(x), len(y)))\n", "        self.x = x\n", "        self.y = y\n", "\n", "    def __getitem__(self, index):\n", "        return self.x[index], self.y[index]\n", "\n", "    def __len__(self):\n", "        return len(self.x)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "dataset = DiabetesDataset(x, y)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "768" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(dataset)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "train_loader = DataLoader(dataset=dataset,\n", "                          batch_size=32,\n", "                          shuffle=True)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "There are 24 batches in the dataset\n", "For one iteration (batch) there are:\n", "Data: torch.Size([32, 7])\n", "labels: torch.Size([32, 1])\n" ] } ], "source": [ "# Inspect the shapes of a single batch.\n", "# Distinct names keep the full-dataset `x` / `y` from being overwritten by a batch.\n", "print(\"There are {} batches in the dataset\".format(len(train_loader)))\n", "batch_x, batch_y = next(iter(train_loader))\n", "print(\"For one iteration (batch) there are:\")\n", "print(\"Data: {}\".format(batch_x.shape))\n", "print(\"labels: {}\".format(batch_y.shape))" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "24.0" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sanity check: samples / batch size (a whole number means there is no partial final batch)\n", "len(dataset) / train_loader.batch_size" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", "    \"\"\"Fully connected binary classifier.\n", "\n", "    Architecture: input_features -> 5 -> 4 -> 3 -> output_features, with tanh after each hidden\n", "    layer and a sigmoid on the output so values lie in (0, 1), as BCELoss expects.\n", "    The hidden layer sizes are hard-coded; they could be made constructor arguments instead.\n", "    \"\"\"\n", "\n", "    def __init__(self, input_features, output_features):\n", "        super(Model, self).__init__()\n", "        self.fc1 = nn.Linear(input_features, 5)\n", "        self.fc2 = nn.Linear(5, 4)\n", "        self.fc3 = nn.Linear(4, 3)\n", "        self.fc4 = nn.Linear(3, output_features)\n", "        self.sigmoid = nn.Sigmoid()\n", "        self.tanh = nn.Tanh()\n", "\n", "    def forward(self, x):\n", "        out = self.tanh(self.fc1(x))\n", "        out = self.tanh(self.fc2(out))\n", "        out = self.tanh(self.fc3(out))\n", "        return self.sigmoid(self.fc4(out))" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# Instantiate the network: one input per feature column, one output probability\n", "net = Model(x.shape[1], 1)\n", "\n", "# Binary cross entropy, averaged over the batch, as the loss function\n", "criterion = BCELoss(reduction='mean')\n", "\n", "# Stochastic gradient descent with momentum\n", "optimizer = SGD(net.parameters(), lr=0.1, momentum=0.9)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train the network and report epoch-level training loss / accuracy every `log_every` epochs\n", "epochs = 200\n", "log_every = 20\n", "for epoch in range(epochs):\n", "    epoch_loss = 0.0\n", "    epoch_correct = 0\n", "    n_seen = 0\n", "    for inputs, labels in train_loader:\n", "        inputs = inputs.float()\n", "        labels = labels.float()\n", "        # Forward propagation\n", "        outputs = net(inputs)\n", "        # Loss calculation (mean over the batch)\n", "        loss = criterion(outputs, labels)\n", "        # Clear the gradients accumulated by the previous backward pass\n", "        optimizer.zero_grad()\n", "        # Back propagation\n", "        loss.backward()\n", "        # Update weights (SGD with momentum)\n", "        optimizer.step()\n", "\n", "        # Accumulate statistics over every batch, not only the last one.\n", "        # `loss` is a batch mean, so weight it by the batch size.\n", "        batch_size = labels.size(0)\n", "        epoch_loss += loss.item() * batch_size\n", "        predictions = (outputs.detach() > 0.5).float()\n", "        epoch_correct += (predictions == labels).sum().item()\n", "        n_seen += batch_size\n", "\n", "    if (epoch + 1) % log_every == 0 or epoch + 1 == epochs:\n", "        # Running metrics: collected while the weights were still changing during the epoch\n", "        print(\"Epoch: {}/{}, Loss: {:.3f}, Accuracy: {:.3f}\".format(\n", "            epoch + 1, epochs, epoch_loss / n_seen, epoch_correct / n_seen))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Next steps\n", "\n", "- The network is trained and evaluated on the same rows, so the accuracy above is training accuracy. Hold out a validation/test split (fitting the `StandardScaler` on the training part only) to measure generalization.\n", "- Consider `BCEWithLogitsLoss` on raw logits instead of `Sigmoid` + `BCELoss` for better numerical stability." ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }