pytorch-stuff/Numpy_NN.ipynb

243 lines
45 KiB
Plaintext
Raw Normal View History

2020-04-27 19:53:55 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
2020-04-28 16:42:58 +00:00
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
2020-04-27 19:53:55 +00:00
"import sklearn.datasets"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 3,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"# Fixed random_state so Restart & Run All always reproduces the same data set.\n",
"X, y = sklearn.datasets.make_moons(n_samples=200, noise=0.15, random_state=0)"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 4,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
2020-04-28 16:42:58 +00:00
"<matplotlib.collections.PathCollection at 0x11e1694d0>"
2020-04-27 19:53:55 +00:00
]
},
2020-04-28 16:42:58 +00:00
"execution_count": 4,
2020-04-27 19:53:55 +00:00
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
2020-04-28 16:42:58 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAD4CAYAAADhNOGaAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOydd3gUVReH37t900gIvQkiVVCaIE2aIAiCSBdEAcUCCoKoKIgiAvIJYqErKha6AoqAIL1KE5DeSyiBhEDK9r3fHxuWhJ30TSHM+zx5YO/M3Dm72cyZOfec3xFSSlRUVFRU7l00uW2AioqKikruojoCFRUVlXsc1RGoqKio3OOojkBFRUXlHkd1BCoqKir3OLrcNiAzFCpUSJYtWza3zVBRUVG5q9i9e/c1KWXhO8fvSkdQtmxZdu3aldtmqKioqNxVCCHOKo37JTQkhJgthIgUQvyXwvaeQoj9QogDQoitQoiHk2w7kzj+rxBCvbqrqKio5DD+WiP4HmidyvbTQBMpZXXgY2DmHdubSSlrSCnr+MkeFRUVFZV04pfQkJRyoxCibCrbtyZ5uR0o5Y/zqqioqKhkndzIGuoHrEjyWgJ/CSF2CyH654I9KioqKvc0ObpYLIRohscRNEoy3EhKGSGEKAKsFkIckVJuVDi2P9AfoEyZMjlib34l/kY8p/afI6xYKKUqFM9tc1RUVHKZHHMEQoiHgG+ANlLKqFvjUsqIxH8jhRC/AXUBH0cgpZxJ4tpCnTp1VKW8TDJ3/G/8NHoReqMOp93J/Q+XZfTStwktXCC3TVNRUcklciQ0JIQoA/wKPCelPJZkPFAIEXzr/0ArQDHzSCXrbF22k5/HLMZutRN/IwGbxc7x3Sf5uOuk3DZNRUUlF/HLE4EQYi7QFCgkhLgAjAL0AFLK6cAHQDgwVQgB4EzMECoK/JY4pgN+kVKu9IdNKr4smvg7tgRbsjGnw8WRHce5eiGKwqXCc8kyFRWV3MRfWUM90tj+IvCiwvgp4GHfI1Syg5irNxXHtXotsdFxqiNQUblHUbWG7iHqPlkTncHX92s0GkpXLpELFqUPu82B2kBJRSX7UB3BPUS3t58mpGAQeqMeACHAGGBkwJd90Rv0fjuPzWLD5XRleZ7tf+zmufIDaBfYk6dDn+eHDxfgcmV9XhUVleTclVpDKpkjrEgBZu6fyG9f/smuVfsoXDqczkOe4sEGlfwy/8GtR5n8ygzOHY5Aq9NSu9VDAFw9H0WNZg/SeWh7CpUomK659m04yJhuk7BZ7AAkxFpY+NlSrPFWXv5fb7/Yq6Ki4kHcjY/cderUkaroXN7iwvFLvFprGNZ4m+J2nUGLOcjM9D0TKFLGR/zQh6HNRrF/wyGfcaPZwKKrszEFGLNss4rKvYYQYreSlI8aGlLxC4sn/Y7D5khxu9PuIv5GAj+OXpiu+SKOX1IcFxrB9SsxmbJRRUVFGdURqPiFMwfP43K6U93H7XKze/X+dM1Xrvp9yhuEIDyd4SUVFZX0oa4RqPiFyvUqcOSfEzjtzlT3K1A4JMVtN67dZOrg79j86w7cLjcajcDtvh26NAUa6f7O0xiM/lvYVlFRUR2Bip94ZlBbVnzzNy6HK8VUT1Ogka5vdVDc5nK6GNRwBJfPROJyeDKDNFoNOoMWISCsaCjd332adi+3yrb3oKJyr6I6AhW/ULhUOF9tH8uMt+awb91BTMEmzEEmrl2IQm/U47S76PJWe5p2a6B4/I4/9xB9+brXCYAnlGQ0Gxgy6xWadmuYU2/Fr0gpib4cgynAQGCBwNw2R0VFEdURqPiN0pVKMub34cnGIs9d5VpENPc9WJrAkIAUjz136AK2BLvPuCXOytlDF/xua06wf+Mh/tdnClGXriPdkprNqvH2nIGqwJ9KnkN1BCrZSpEyhdOVLlq6ckmMAQYssdZk4+YgE6Url8wu8zLFzehYfp+2in/XHq
TY/UV4ZlBbylVLLo1+6dQV3m87Nlk67d61B3j3iTFM2z2BRH0tFZU8geoIVHKdqEvX2fTrDp8aBI1WQ1BYII2eqZdLlvkSffk6r9R6m/iYeOxWB5qNGtbN3czIBUOp92Qt735Lp670WTh3OlxEHL/Esd2nqFSnfE6brqKSImr6qEquEn8zgdfqvMP6eVuQSTKEhIBH29Xmy21j81SW0E8fL+JmVCx2q6dmwu1yY0uwM+mlabjdt9NnI45dwunwlcPQaDVEnr2aY/aqqKQH1RGopIvLZyI5vucU9lSKxjLDXz+sJy4mzkebSEroMvSpdEtS5BTb/9idbEH7FvE3Erhy5vYF/qEmVTGaDT77Oe1OHqhVLlttVFHJKKojUEmV61diGNTwffpVHcxbzT+kS5F+rPxurd/mP7ztGHaLsnOZPvSHVI+Ni4lnydcr+Or1b1j94wbsVt/FZn8TFKac+eN2uQkIMXtft+nXgsDQQLQ6rXfMGGDksS71KV6uaLbbqaKSEVRHoJIqI54az9GdJ7FbHSTctJAQa+Hr12dzcOtRv8yf2kLwib2ncdiVncTZQ+d5rvwAvnn3Z5ZNWcWXA76h34NvcuOacs8Ff9FpUDtMgcl1jnR6LQ89VpUChW4XywWFBjJt96c80acZBYuFUuKBYvT9pAdvzX4tW+1TUckMqiNQSZFzRyI4e+i8T9jGbrGx+PM//HKOJ196PNXtSnF2gAkvTCE+Jt7bcc0aZ+XahSi+GzHPL3alRKsXmtKmXwv0Rj2BBcwYA4yUr1GW4T8P8tm3YLEw3pzxMvMvzuKHY1/xzKC2aLVahVlVVHIXvzgCIcRsIUSkEEKx37Dw8KUQ4oQQYr8QolaSbc8LIY4n/jzvD3tU/MP1KzHo9L6JZVLC1QtRfjlHePEw6jyh3KSu3EP3YQ40+YzH30zg5L4z3FnA7HS42LR4u1/sSgkhBK9N7sNPp6fw/tw3+Wr7WL7eMT7Z04CKyt2Gv54Ivgdap7K9DVAh8ac/MA1ACFEQT3/jekBdYJQQIsxPNt3z3IyOZf/GQ1w6dcU75nQ4ObTtKEd3nUyW5aLEAzXL4bD5agfpTXrqtqnpNzuHzHqV0CIFMCQuruqNegKCzQyd9ari/hptyl9brT5n7rgLFgvjkdY1feoHVFTuRvzVs3ijEKJsKrt0AOZIjwjNdiFEqBCiOJ6G96ullNEAQojVeBzKXH/Yda8ipWT2+7/w6+Tl6I16HHYnVR+tSPvXWjGp/wzcTjdSSgILBDB66TtUqHW/4jyBIQE8N6ozP49Z7M3x1xt1hIQH02Fgan4/YxQuFc53R75g5ey1HNlxnDJVS9G2f0vCiyvfE5gDTTzc5EH+XfcfbtdtZ2Yw6XnihaZ+s0tF5V7Bb41pEh3BH1LKagrb/gDGSyk3J77+G3gHjyMwSSnHJI6PBCxSys8U5uiP52mCMmXK1D579qxf7L5bObb7JGt+2ojT7qRJ1wY89FhVb7Xq6h838OWrs7Am3C7Q0um1uF3uZGqe4FnUnBcxA6M55UYvO5bvZtHnf3Aj8ib12tWm85B2uR4KuXYxmjcbj+TGtZu4HC40Wg0P1CzH+FUjUn0vKir3Mik1prlrKoullDOBmeDpUJbL5uQoDruDTYu2s3vNfgqXCsdmsfP71FWenH4pWT1nA82fbcTg6S8jhGDhxN+TOQFIedHV5XKz/ffdNOmqLAYHUK9tbeq1re3X95RVCpUoyPfHvmTXqn1cPh3JAzXLUbV+xTwl3bB79T5mvv0j5w5HEF4ijOc+6MITLzTLbbNUVHzIKUcQAZRO8rpU4lgEnqeCpOPrc8imuwJLvJXBDUdw8eRlrPE2dHqtz0XdGm9j7S+beeKFZlStX4nYqNh0z+92uriZgf1zC5fTc9ef9EKv1Wq9sg7xNxO4fCaSwqXCFRe4cwIpJeePXsRhc3Dj2k1GPT3B23P5ypmrfDXwWxJiLXR8/clcsU9FJSVy6i9mGTBQCDEPz8
LwDSnlJSHEKmBskgXiVsDwlCa5F1ny5Z9cOH4Je+IFJaU7e2uCjS1L/qFq/UrUbvUwq+dsSBY/B0AACs9SDzV90M9
2020-04-27 19:53:55 +00:00
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Explicit fig/ax interface; the trailing ';' suppresses the artist repr.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(X[:, 0], X[:, 1], c=y)\n",
"ax.set(title='make_moons training data (colour = class)', xlabel='$x_1$', ylabel='$x_2$');"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 5,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(200, 2)"
]
},
2020-04-28 16:42:58 +00:00
"execution_count": 5,
2020-04-27 19:53:55 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# rows = samples, columns = input features\n",
"np.shape(X)"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 6,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"# Network / training hyperparameters\n",
"input_neurons = 2      # features per sample\n",
"output_neurons = 2     # output columns = number of classes\n",
"samples = len(X)       # number of training examples\n",
"learning_rate = 0.001  # gradient-descent step size\n",
"lambda_reg = 0.01      # L2 regularisation strength\n"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 7,
2020-04-27 19:53:55 +00:00
"metadata": {},
2020-04-28 16:42:58 +00:00
"outputs": [
{
"ename": "NameError",
"evalue": "name 'W1' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-5283970c2356>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel_dic\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'W1'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mW1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b1'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mb1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'W2'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mW2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'b2'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mb2\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'W1' is not defined"
]
}
],
2020-04-27 19:53:55 +00:00
"source": [
"# Initialise the network parameters (they were never defined, hence the\n",
"# NameError). A fixed seed keeps Restart & Run All reproducible; scaling by\n",
"# 1/sqrt(fan_in) keeps the initial tanh units away from saturation.\n",
"hidden_neurons = 3  # NOTE(review): hidden width is not set in the hyperparameter cell; tune as needed\n",
"rng = np.random.default_rng(0)\n",
"W1 = rng.standard_normal((input_neurons, hidden_neurons)) / np.sqrt(input_neurons)\n",
"b1 = np.zeros((1, hidden_neurons))\n",
"W2 = rng.standard_normal((hidden_neurons, output_neurons)) / np.sqrt(hidden_neurons)\n",
"b2 = np.zeros((1, output_neurons))\n",
"model_dic = {'W1': W1, 'b1': b1, 'W2': W2, 'b2': b2}"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": null,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"def retrieve(model_dict):\n",
" W1 = model_dic['W1']\n",
" b1 = model_dic['b1']\n",
" W2 = model_dic['W2']\n",
" b2 = model_dic['b2']\n",
" return W1,b1,W2,b2\n"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 12,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"def forward(x, model_dict):\n",
" W1,b1,W2,b2 = retrieve(model_dict)\n",
" z1 = x.dot(W1) + b1\n",
" a1 = np.tanh(z1)\n",
" z2 = a1.dot(W2) + b2\n",
" a2 = np.tanh(z2)\n",
" exp_scores = np.exp(a2)\n",
" softmax = exp_scores / np.sum(exp_scores, dim=1, keepdims=True)\n",
2020-04-28 16:42:58 +00:00
" return z1,a1,softmax\n",
2020-04-27 19:53:55 +00:00
" "
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 13,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
"def loss(softmax, y):\n",
" W1,b1,W2,b2 = retrieve(model_dict)\n",
" m = np.zeros(200)\n",
" for i,correct_index in enumerate(y):\n",
" predicted = softmax[i][correct_index]\n",
" m[i] = predicted\n",
" log_prob = -np.log(predicted)\n",
" softmax_loss = np.sum(log_prob)\n",
" reg_loss = lambda_reg / 2*(np.sum(np.square(W1)) + np.sum(np.square(W2)))\n",
" loss = softmax_loss + reg_loss\n",
" return float(loss/y.shape[0])"
]
},
{
"cell_type": "code",
2020-04-28 16:42:58 +00:00
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def predict(x, model_dict):\n",
" W1,b1,W2,b2 = retrieve(model_dict)\n",
" z1 = x.dot(W1) + b1\n",
" a1 = np.tanh(z1)\n",
" z2 = a1.dot(W2) + b2\n",
" a2 = np.tanh(z2)\n",
" exp_scores = np.exp(a2)\n",
" softmax = exp_scores / np.sum(exp_scores, dim=1, keepdims=True)\n",
" return np.argmax(softmax, axis=1)\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
2020-04-27 19:53:55 +00:00
"metadata": {},
"outputs": [],
"source": [
2020-04-28 16:42:58 +00:00
"def backpropagation(x, y, model_dict, epochs, print_every=50):\n",
"    \"\"\"Train the 2-layer network with full-batch gradient descent.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : input array, one sample per row.\n",
"    y : integer class labels, one per row of ``x``.\n",
"    model_dict : dict with 'W1', 'b1', 'W2', 'b2'. It is copied, not modified.\n",
"    epochs : number of gradient-descent steps.\n",
"    print_every : print the loss every this many epochs (0 or None disables).\n",
"\n",
"    Returns\n",
"    -------\n",
"    dict with the trained 'W1', 'b1', 'W2', 'b2' (float arrays).\n",
"\n",
"    Uses the notebook-level ``learning_rate`` and ``lambda_reg`` and the\n",
"    ``forward(x, model_dict) -> (z1, a1, probs)`` and\n",
"    ``loss(probs, y, model_dict)`` helpers.\n",
"    \"\"\"\n",
"    # Work on float copies so the caller's arrays are not silently overwritten;\n",
"    # re-running with the same initial dict then gives the same result.\n",
"    params = {k: np.array(model_dict[k], dtype=float) for k in ('W1', 'b1', 'W2', 'b2')}\n",
"    W1, b1, W2, b2 = params['W1'], params['b1'], params['W2'], params['b2']\n",
"    n = x.shape[0]\n",
"    for epoch in range(epochs):\n",
"        z1, a1, probs = forward(x, params)\n",
"        if print_every and epoch % print_every == 0:\n",
"            # Loss of the parameters that produced ``probs`` (before this update).\n",
"            print(\"Loss at epoch {} is: {}\".format(epoch, loss(probs, y, params)))\n",
"        # Gradient of softmax + cross-entropy w.r.t. output scores: probs - one_hot(y).\n",
"        delta3 = probs.copy()\n",
"        delta3[np.arange(n), y] -= 1\n",
"        dW2 = a1.T.dot(delta3)\n",
"        db2 = np.sum(delta3, axis=0, keepdims=True)\n",
"        # Back through tanh: d/dz tanh(z) = 1 - tanh(z)**2 = 1 - a1**2.\n",
"        delta2 = delta3.dot(W2.T) * (1 - np.square(a1))\n",
"        dW1 = x.T.dot(delta2)\n",
"        db1 = np.sum(delta2, axis=0, keepdims=True)\n",
"        # Gradient of lambda/2 * ||W||^2 is lambda * W (element-wise), not lambda * sum(W).\n",
"        dW2 += lambda_reg * W2\n",
"        dW1 += lambda_reg * W1\n",
"        # In-place updates keep ``params`` pointing at the current weights;\n",
"        # reshape the bias gradients so 1-D or (1, k) biases both work.\n",
"        W1 -= learning_rate * dW1\n",
"        b1 -= learning_rate * db1.reshape(b1.shape)\n",
"        W2 -= learning_rate * dW2\n",
"        b2 -= learning_rate * db2.reshape(b2.shape)\n",
"    return params\n"
2020-04-27 19:53:55 +00:00
]
2020-04-28 16:42:58 +00:00
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
2020-04-27 19:53:55 +00:00
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}