pytorch-stuff/Numpy_NN.ipynb

295 lines
42 KiB
Plaintext
Raw Normal View History

2020-04-27 19:53:55 +00:00
{
"cells": [
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 1,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 16:42:58 +00:00
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
2020-04-28 18:56:54 +00:00
"import sklearn.datasets\n"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 2,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 18:56:54 +00:00
"# random_state fixes the sampling noise so re-running the notebook gives the same data.\n",
"X, y = sklearn.datasets.make_moons(200, noise=0.15, random_state=0)\n"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 3,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2020-04-28 18:56:54 +00:00
"<matplotlib.collections.PathCollection at 0x7f0d41ee6e10>"
2020-04-27 19:53:55 +00:00
]
},
2020-04-28 18:56:54 +00:00
"execution_count": 3,
2020-04-27 19:53:55 +00:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2020-04-28 18:56:54 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOydZ5QURReGn+rJs4ldspJFURRRQAmSM4igoCQlqYgo8mFCxQgCgoEkYkRFBUEUFAUlSQaRnHOSDAssmyZ3fT9mGXaY2TzLpn7O4Sg13VXVQ8/t6lv3vldIKdHQ0NDQKPwoeT0BDQ0NDY3rg2bwNTQ0NIoImsHX0NDQKCJoBl9DQ0OjiKAZfA0NDY0igj6vJ5AWJUqUkJUqVcrraWhoaGgUKDZt2hQrpSwZ7LN8a/ArVarExo0b83oaGhoaGgUKIcSxtD7TXDoaGhoaRQTN4GtoaGgUETSDr6GhoVFE0Ay+hoaGRhFBM/gaBZLTh8+yfeVu4i8m5PVUNDQKDPk2SkdDIxhJl5N4p8uH7F63H4NRj8vh4qH/3c8To3sihMjr6Wlo5Gu0Fb5GgeL9vp+wa81enDYnSZeTcdpd/Db5T5Z8vzKvp6ahke/RDL5GgSExLokNf23B5XD7tduTHMz+aF4ezUpDo+CgGXyNAkNyfDKKEvyWjb+g+fI1NDJCM/gaBYYS5YpjjbIGtCs6hdqta+bBjDQ0ChaawdcoMCiKwpBPn8JkMfo2aA1GPWFRVvq80zWPZ6ehkf/RonQ0ChQNOt3DuJUjmP3R75w+dIY7m9xOl+c7ULxsdF5PTUMj3yPya03bOnXqSE08TUNDQyNrCCE2SSnrBPtMc+loaGhoFBE0g6+hoaFRRNAMvoaGhkYRQTP4GhoaGkUEzeBraGhoFBE0g6+hoaFRRNAMvoaGhkYRQTP4GhqpsCXZib+QQH7NT9HQyAlapm0BwOPxcPLAGcKLWYkpo2WU5gaJcUl8+MQU1s/fDECpCiV4aeoz1Gh0Wx7PTEMjdGgGP5+zeu56xg/4HKfdhcfloXr9W3hj1vMUKxmV11MrVAxrP5oDmw/jdnqll08dPMOwdqP4fNuH3HBTmTyenYZGaNBcOvmYg1uPMKbXJOJjE7An2nE5XOxcs5dh7Ufn9dTyNVJK5n68gO7lB9De0pPn6g9j97p9aR5/ZMcxDm8/5jP2V3C53MydtCDoOUnxySz4ainTR/7Clr93aC4gjQKBtsLPx8yZOB+X3eXX5nF5+G/PSY7s/I/Kd1TIo5nlb759ayZzxs/HnuwAYO/6AwxtNYIJq0ZS9e7KAcefPnIOnUEX0O5xeTi+71RA+4HNh3m5xXA8bg8OmxOT1US1Ojfx3l+vYzAaQn9BGhohQlvh52POHj2PqgauHHUGHRdOXcqDGeV/7MkOfhn3h8/YX8Fpc/Ld8J+CnlP1rkq4HK6AdqPZwB0Nb/Vrk1Iy4pGPSLqcjD3JgVQl9kQ7O1fvZfZHv4fuQjQ0cgHN4OdjarWsgdEcuGJ0O1xUvbvS9Z9QAeD88VgUXeBtLSUc2no06DmlKpSkSdcGmKwmX5uiU7CEW3jg6dZ+x548cJpLZy8H9OFxe/h++GwcNkfAZxoa+QXN4OdjHhjYhoiYcPSp3A3mMBMPDm6fq5u2tiQ7P4/7nSGN3uDNjmPYuGhbro0VaorfEIPH7Qn6WYXbbkzzvJemDqTviG6UqVSKyBIRNOvRkCmbxhJVItJ3jC3RxpEd/6F61KB9eFwels1cm7ML0NDIRTQ9/HzOpXOXmTlmLuvmbSSyeDhdnn+Apt0a+Co+hRp7soNBdV/jzOGzOGxOAMxWE91fe4hHX++SK2OGmk/+9zV/Tl2KI9npazNZjLy/5C2q16+WqT42L9nO9+/O5szhc9xS5yZKlS/Bn1OXougVbAn2NM9r2asxr0x7LsfXoKGRXdLTw9c2bfM50aWiGDiuLwPH9b0u4y
2etpwzR875jD14HwIzRv1ChwGt/Fa8+ZWnx/UhLMrKnInzsSc6uKFqGZ6d2C/Txn7ZrDV89MQU3wMj9uTFTJ2nN+ooU7lUhsc57U6Wz1rL9pW7KVulFG36NafEDTGZGkNDIydoBl/Dj3W/b8SRHOiH1hv17PnnAPU61M6DWaWNlJJda/exafE2IoqF07R7A2LKRNN3RHf6DO+Gx+1Bb8j8ba6qKp8O+dbv7SCz6A162j/ZMt1jEuOSeK7ea8SevIg9yYHRbGDmmN8Yu+iNTD+QNDSyi2bwNfyILl0MoQjkNdFBUpVEFg8P2Tgej4eVs//h7xmrMJoNtH2iBXVa18ySq0pVVUb1mMC/CzZjT3ZgNBn4+vUZvPXzS9zb7m6EEOka+/MnLrB6znpUj0qDTvdQtkppEi4mkhiXmKnxhSLQ63UoeoWoEpG8+v1gSpYrnu45M0bP4eyx87gc3ph/p90FuBjT62OmHfg411x1GhqgGXyNa+j4bFtW/LTWz6UjFEFUyUhuq3dLhucnJ9j4549NOJId1G5dk1LlSwQco6oqb3Ycy46Vu7Ened8m/v1zCx0GtGbAh739jj207ShzJsznzNFz3N2iBh2faUNkTAQAq35Z7zX2KX04U3IWRvUYz+yzUzGa0o6J/3PqUiY/NxXwviV8/foM+ozozkOD2yGUzMUyRESH8enmD/C4PZSpVCpTxnrl7HU+Y5+aC6cvcf54LKUqlMzU2Boa2SEkBl8I8TXQATgnpbwjyOcCmAi0B5KBvlLKzaEYWyO0VKtzE4MmP8Eng79G0SmoHpUSN8Ywav4wP4O2f9Mhvh8+myM7/6PS7eXp9XZXki4n8/aDY31vCKpH5dE3utBzmP9m76ZF29ixao/PUAPYkxzMm/IXHZ9pQ9kqpQFYO28Do3tOwGV3oaqSvesP8NvHf9Lr7UeIKhnJwm+W+fWRmp2r9lCr5Z1BP4s9eYHJz031PSCuMO2tWdTrUJv7+7dk3qcL8biCR/sAmKwmnp30eNAHWlpcjo0nMS4p6GdSVTGk84DS0AgFoVrhfwtMBr5L4/N2wM0pf+oCn6b8VyMXSYxLQtEpWCMsWTqvbb/mNO12H/s3HiIsykqVOyv6GfvtK3czrP0onDYnUsK5Y+fZ8vcOQOC0+fu+Z4yeQ62Wd3LrvTf72tbP34Q9MTDSRSgKm5ds5/6nWuHxeBjX/zM/X7rT7sJpd/HxoKkZXoNQ0l5tr/l1AwRZjXvcHlb8tJb6HeukKamAgIrVyzPk0/7c0TDzwmr2ZAeD7n0NW5DrVnQKVe+uQnTpYpnuT0MjO4TE4EspVwohKqVzSCfgO+mNAf1HCFFMCFFWSnk6FOMXNs79d56PB01l48KtKHodTbs2YOD4voQXC8vU+cd2H2ds78kc3nEMAdzR8FaGTnsuQ/9yasxWE3c2rh70sylDvvEzxFKC0+YKamSddheLpi2n0h0V0Bt06A16X26B+5oVtKIThEVZATh9+FzQzePMoChKQIZsaqSU3kkHaZdS8unz36Z5riXczFPv98qSsQdYPnMNcecvB43hL1YqijdmDslSfxoa2eF6JV7dCBxP9fcTKW1+CCGeEkJsFEJsPH/+/HWaWv7ClmhjUN3X+PfPLbhdHpw2J3//uJqXmr+TKYGuxLgkhjR8k4NbDuNxeXC7PGxfuYchDd9IMyEpqxzZ8V/Q9ms3eq+0LZ+1lk5RvXkgohejek6gwYN1UfSB2jVCCOo94A0f1umVLEXK6PQK5jAT5jATb//yUrqaNg06Bg1RRm/U0/Che9O8PvBuatduHdxVlB47V+8J6n4ymAz0fucRzXevcV3IV5m2UsovpJR1pJR1SpYsmj+ApdNXY0u0+60E3U43pw6eYfuK3RmfP2MVLqfbbwGrelQSLiXy759bQjLHiOjMvWlcIeFiIqpHxe10s2bOeiY98yVDv3kWc5gJa6QFa4SFyOLhvPfn65hT5A3mTJif6f4NJgOt+zRj4Ph+zP
jvM+5qFrCN5EepCiXp//5jGM0G9AYdOr0Ok8VIt6GdqFyjYpouML1Rz4RV76LTBT6sMqLcLTcElcnQG3WUrVw6y/1
2020-04-27 19:53:55 +00:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
2020-04-28 18:56:54 +00:00
"# Training data coloured by class label; trailing ';' suppresses the PathCollection repr.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(X[:, 0], X[:, 1], c=y)\n",
"ax.set(title='make_moons training data', xlabel='X[:, 0]', ylabel='X[:, 1]');"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 4,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [
{
2020-04-28 18:56:54 +00:00
"name": "stdout",
"output_type": "stream",
"text": [
"(200, 2) (200,)\n"
]
2020-04-27 19:53:55 +00:00
}
],
"source": [
2020-04-28 18:56:54 +00:00
"# One integer label per sample: loss() and backpropagation() index y row by row.\n",
"assert X.ndim == 2 and y.shape == (X.shape[0],), 'expected X: (n, d) and y: (n,)'\n",
"print(X.shape, y.shape)"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 5,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters\n",
"input_neurons, output_neurons = 2, 2  # 2 input features per point, 2 softmax outputs (classes)\n",
"samples = X.shape[0]  # number of training examples\n",
"learning_rate = 0.001  # gradient-descent step size\n",
"lambda_reg = 0.01  # L2 penalty strength used by loss() and backpropagation()\n"
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 6,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 18:56:54 +00:00
"def retreive(model_dict):\n",
"    \"\"\"Unpack the network parameters stored in `model_dict`.\n",
"\n",
"    Returns the tuple (W1, b1, W2, b2); raises KeyError if any key is missing.\n",
"    \"\"\"\n",
"    return tuple(model_dict[name] for name in ('W1', 'b1', 'W2', 'b2'))\n"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 7,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"def forward(x, model_dict):\n",
"    \"\"\"Run the tanh-hidden-layer / softmax network on a batch of inputs.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : ndarray, assumed (n_samples, input_dim)\n",
"    model_dict : dict with keys 'W1', 'b1', 'W2', 'b2'\n",
"\n",
"    Returns\n",
"    -------\n",
"    z1 : hidden pre-activations\n",
"    a1 : tanh(z1)\n",
"    softmax : class probabilities, each row sums to 1\n",
"    \"\"\"\n",
"    W1, b1, W2, b2 = retreive(model_dict)\n",
"    # Use the argument x (not the global X) so any batch can be evaluated.\n",
"    z1 = x.dot(W1) + b1\n",
"    a1 = np.tanh(z1)\n",
"    z2 = a1.dot(W2) + b2\n",
"    # Softmax is invariant to a per-row shift; subtracting the row max keeps\n",
"    # np.exp from overflowing for large logits.\n",
"    exp_scores = np.exp(z2 - np.max(z2, axis=1, keepdims=True))\n",
"    softmax = exp_scores / np.sum(exp_scores, axis=1, keepdims=True)\n",
"\n",
"    return z1, a1, softmax\n"
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 8,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 18:56:54 +00:00
"def loss(softmax, y, model_dict):\n",
"    \"\"\"Mean cross-entropy of `softmax` against integer labels `y`, plus an L2 penalty.\n",
"\n",
"    The penalty lambda_reg / 2 * (||W1||^2 + ||W2||^2) uses the global `lambda_reg`\n",
"    and, like the data term, is divided by the number of samples len(y).\n",
"    Returns a Python float.\n",
"    \"\"\"\n",
"    W1, _, W2, _ = retreive(model_dict)\n",
"    n_samples = y.shape[0]\n",
"    # Probability given to the correct class of every sample. Sized by len(y)\n",
"    # rather than a hard-coded 200, so any batch size works.\n",
"    correct_probs = softmax[np.arange(n_samples), y]\n",
"    data_loss = -np.sum(np.log(correct_probs))\n",
"    reg_loss = lambda_reg / 2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))\n",
"\n",
"    return float((data_loss + reg_loss) / n_samples)"
2020-04-27 19:53:55 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 9,
2020-04-28 16:42:58 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 18:56:54 +00:00
"def predict(model_dict, x):\n",
"    \"\"\"Return the predicted class index for each row of `x`.\n",
"\n",
"    Softmax is monotonic within a row, so the argmax of the logits z2 equals the\n",
"    argmax of the probabilities. Skipping np.exp avoids overflow to inf/nan for\n",
"    large logits.\n",
"    \"\"\"\n",
"    W1, b1, W2, b2 = retreive(model_dict)\n",
"    z1 = x.dot(W1) + b1\n",
"    a1 = np.tanh(z1)\n",
"    z2 = a1.dot(W2) + b2  # (n_samples, output_dim)\n",
"\n",
"    return np.argmax(z2, axis=1)  # (n_samples,)\n"
2020-04-28 16:42:58 +00:00
]
},
{
"cell_type": "code",
2020-04-28 18:56:54 +00:00
"execution_count": 10,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 18:56:54 +00:00
"def backpropagation(x, y, model_dict, epochs, print_every=50):\n",
"    \"\"\"Train the network with full-batch gradient descent and return the new parameters.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : ndarray of inputs; y : integer class labels, one per row of x\n",
"    model_dict : dict with 'W1', 'b1', 'W2', 'b2'. Not modified: every step builds\n",
"        new arrays.\n",
"    epochs : number of gradient steps\n",
"    print_every : print the loss every this many epochs (default 50)\n",
"\n",
"    Reads the globals `learning_rate` and `lambda_reg`.\n",
"    \"\"\"\n",
"    for i in range(epochs):\n",
"        W1, b1, W2, b2 = retreive(model_dict)\n",
"        z1, a1, probs = forward(x, model_dict)\n",
"\n",
"        # Report the loss of the same parameters that produced `probs`.\n",
"        if i % print_every == 0:\n",
"            print(\"Loss at epoch {} is: {:.3f}\".format(i, loss(probs, y, model_dict)))\n",
"\n",
"        # Softmax + cross-entropy gradient w.r.t. z2: probs minus one-hot(y).\n",
"        delta3 = np.copy(probs)\n",
"        delta3[np.arange(x.shape[0]), y] -= 1\n",
"        dW2 = a1.T.dot(delta3)\n",
"        db2 = np.sum(delta3, axis=0, keepdims=True)\n",
"        # tanh'(z1) = 1 - tanh(z1)^2 = 1 - a1^2\n",
"        delta2 = delta3.dot(W2.T) * (1 - np.power(a1, 2))\n",
"        dW1 = x.T.dot(delta2)\n",
"        db1 = np.sum(delta2, axis=0, keepdims=True)\n",
"\n",
"        # The gradient of lambda_reg / 2 * ||W||^2 is the element-wise lambda_reg * W,\n",
"        # not the scalar lambda_reg * sum(W).\n",
"        dW2 += lambda_reg * W2\n",
"        dW1 += lambda_reg * W1\n",
"\n",
"        # Out-of-place update (W - lr * gradient) so the caller's arrays stay untouched.\n",
"        model_dict = {\n",
"            'W1': W1 - learning_rate * dW1,\n",
"            'b1': b1 - learning_rate * db1,\n",
"            'W2': W2 - learning_rate * dW2,\n",
"            'b2': b2 - learning_rate * db2,\n",
"        }\n",
"\n",
"    return model_dict\n"
2020-04-27 19:53:55 +00:00
]
2020-04-28 16:42:58 +00:00
},
2020-04-28 18:56:54 +00:00
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Define Initial Weights\n",
"def init_network(input_dim, hidden_dim, output_dim, seed=None):\n",
"    \"\"\"Create randomly initialised parameters for the two-layer network.\n",
"\n",
"    Weights are drawn from N(0, 1) and scaled by 1/sqrt(fan_in) (the scaling this\n",
"    notebook refers to as Xavier initialisation). Biases start at zero with\n",
"    shape (1, dim).\n",
"\n",
"    seed : optional int. When given, weights come from a private\n",
"        np.random.RandomState(seed), so the result is reproducible. When None\n",
"        (the default), the global NumPy random state is used, as before.\n",
"\n",
"    Returns a dict with keys 'W1', 'b1', 'W2', 'b2'.\n",
"    \"\"\"\n",
"    randn = np.random.randn if seed is None else np.random.RandomState(seed).randn\n",
"    # Dict literal evaluates in order, so W1 is drawn before W2 (same as before).\n",
"    return {\n",
"        'W1': randn(input_dim, hidden_dim) / np.sqrt(input_dim),\n",
"        'b1': np.zeros((1, hidden_dim)),\n",
"        'W2': randn(hidden_dim, output_dim) / np.sqrt(hidden_dim),\n",
"        'b2': np.zeros((1, output_dim)),\n",
"    }\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epoch 0 is: 0.773\n",
"Loss at epoch 50 is: 0.330\n",
"Loss at epoch 100 is: 0.273\n",
"Loss at epoch 150 is: 0.262\n",
"Loss at epoch 200 is: 0.259\n",
"Loss at epoch 250 is: 0.258\n",
"Loss at epoch 300 is: 0.258\n",
"Loss at epoch 350 is: 0.258\n",
"Loss at epoch 400 is: 0.257\n",
"Loss at epoch 450 is: 0.257\n",
"Loss at epoch 500 is: 0.257\n",
"Loss at epoch 550 is: 0.256\n",
"Loss at epoch 600 is: 0.256\n",
"Loss at epoch 650 is: 0.256\n",
"Loss at epoch 700 is: 0.255\n",
"Loss at epoch 750 is: 0.255\n",
"Loss at epoch 800 is: 0.255\n",
"Loss at epoch 850 is: 0.255\n",
"Loss at epoch 900 is: 0.254\n",
"Loss at epoch 950 is: 0.254\n",
"Loss at epoch 1000 is: 0.254\n",
"Loss at epoch 1050 is: 0.254\n",
"Loss at epoch 1100 is: 0.253\n",
"Loss at epoch 1150 is: 0.253\n",
"Loss at epoch 1200 is: 0.253\n",
"Loss at epoch 1250 is: 0.253\n",
"Loss at epoch 1300 is: 0.252\n",
"Loss at epoch 1350 is: 0.252\n",
"Loss at epoch 1400 is: 0.252\n",
"Loss at epoch 1450 is: 0.252\n"
]
}
],
"source": [
"# Seed NumPy's global RNG so the random initial weights (and the loss log) are reproducible.\n",
"np.random.seed(0)\n",
"model_dict = init_network(input_dim=input_neurons, hidden_dim=3, output_dim=output_neurons)\n",
"\n",
"model = backpropagation(X, y, model_dict, epochs=1500)"
]
},
2020-04-28 16:42:58 +00:00
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2020-04-27 19:53:55 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}