197 lines
40 KiB
Plaintext
197 lines
40 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import numpy as np"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import sklearn.datasets"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 11,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Two-moons toy dataset: X holds the 2-D sample points, y the class label of each point.\n",
  "# random_state pins the draw so a Restart & Run All always sees the same data.\n",
  "X,y = sklearn.datasets.make_moons(200,noise=0.15,random_state=0)"
 ]
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"<matplotlib.collections.PathCollection at 0x11f466f90>"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nOydeZzN5RfH38/37nd2xposKSIlOxVKIkWUkpRUWkhJqxItPyVKRVlalBYhiuxKtNh3yb6vYYwZs939fp/fH3cMd+69Y8bcuTPM9/16edV8l+d7LnfO93nOc87nCCklGhoaGhqXPkpxG6ChoaGhERk0h6+hoaFRStAcvoaGhkYpQXP4GhoaGqUEzeFraGholBL0xW1AKBITE2X16tWL2wwNDQ2Ni4r169cnSynLBTtXYh1+9erVWbduXXGboaGhoXFRIYQ4GOqcFtLR0NDQKCVoDl9DQ0OjlKA5fA0NDY1SgubwNTQ0NEoJmsPXiBiHdhzl36XbsWfai9sUDY1SSYnN0tG4dDh1LJUhnd7j0I6j6PQ6vG4vjw17gHue61jcpmlolCq0Gb5GkfNG5xHs23wQp82FLd2O0+5i4uCpbPh9c3GbpqFRqtAcvkaRcmT3MQ5uPYzXo/odd2Q5+XnUvGKySkOjdKI5fI0iJT05HZ1BF/Rc6vHTEbZGQ6N0ozl8jSLlivrVUb1qwHGD2UDzTo2KwSINjdKL5vA1/JBS8s+fW5kxeh6r563H6/UWajyz1cRTIx/GZDXlHDOYDMSXi6XLsx0Ka66GhkYB0LJ0NHKwZ9p56da3Obz9KB63F71RT1xiDKOWvUPZSgkXPG7Hp9px+dWXMWPUPE4dS6XZHQ3p8mwHYhKiw2i9hobG+dAcvkYOEwdPZf/mg7idHgDcTjcuu5MPHx/PsHmDCjV2/dbXUL/1NeEwU0ND4wLRQjoaOSz+YWmOsz+D16OyYdFmXE53MVmloaERLjSHr5GD1xM8Xi+lRKqBG68aGhoXF5rD18jhpnuaBaRQCiGo27wWJospxF0aGhoXC5rDL2X8+s0f9KzZjzssPejb+BU2Lvk359zjwx+k3GVlsUSbATBHmYgpE8WLXz9dXOZqaGiEESGlLG4bgtK4cWOpdbwKLzM/nc9Xr03GaXPmHDNZjAxb8DrXtaoLgMvpZulPq9i1bg+XX12FNj1uwhpjiaid9iwHcz/7jaU/ryI6IYrO/TrQ7I6GEbVBQ+NiRQixXkrZOOg5zeGXDrxeL/eW603m6ayAc9fcWJtRS98pBqsCcdqd9Gv6Gsf3ncBpdwG+lca9L95Fr7e6FbN1Gholn7wcvhbSKSVkpGTmONDcHNx6JMLWhGbRd39zfH+Sn62OLCfT3v+F1KS0YrRMQ+PiR3P4pYTo+Cj0xuCaNpWuqBBha0Kzet56v5DTGfRGPdtX7ioGizQ0Lh00h19K0Bv03PfiXX4SBwAmq5FHhnaPqC1Hdv3Hq7e/w+3G7twV15Ox/b/Cafc5+TKVElB0gV9LqUpiy2qVuRoahUFz+KWIh4bcy0ND7iU6PgqhCMpfnsjLE5+haYcGEbPh9Mk0nm0xiA2LNuP1eLFnOJg/YTFD7hoOQKe+7TAY/QvAhRDElo2h7g21I2anhsaliCatUIoQQtB9YBfuf6UzbpcHo8kQcRvmfr4Il93FuckCLoebbSt3sf/fg1x5fQ0GfPEUn/T9EqEIVK9K4mVleXfeayiKNj/R0CgMmsMvhQghisXZA+xevw+XI1CmQafTcXDbEWpcW422D7aiVdfm7Fq3l6g4K9XrVUUIUQzWamhcWmhTJo2IcmWDGhjNgS8br1fl8qsvy/nZaDZS76Y61Li2mubsNTTChObwNSJKx6duw2AycK4PN5gM1G5Sk5r1qxebXRoapQHN4ZcwHDYn875YxDv3f8QXr3zPf3uPF7dJYSWhQjyjV7zLda2vQSgCo8XIbb1a887c14
rbNA2NSx6t0rYEkZWWRb+mr3HqvxQcWU70Bh06g563ZrxM43b1i9u8sHPmu1eSQzZJh5PZvX4f5asmcmWDGiXaVg0NyLvSVtu0LUFM+2A2SYeScWdrz3vcXjxuL+8/MoapRz6/5LJUSrLzVFWV0X2+4PdJf6M36lG9KlVqVea9ha8TXy6uuM3T0LggLi0PcpHz98+rcpz9udgz7BzZdawYLCq9zP18EYsnL8PlcGNLt+PIcnJgyyGGP/RJcZumoXHBaA6/BGGJCq45r3pVzCHOaRQNsz5dECDx4HF72fzXNtJTMorJKg2NwqE5/BLEXf06BEgfKDqF6vUup/zlicVkVckl5Xgq41/4hsfqDuClNm+xet76sI1ty7AHPS4UgSMrUOtHQ+NiQHP4JYh2vVrT5oEbMZoNWGIsWGLMlK+ayJs/vVTcppU4UpPSeOr6l5k9diGHdxzlnz+38k73j5n+4eywjN+8Y6OA7l8A8eXjKFelbFieoaERabQsnRLIsX0n2L56N2UrJXBtqzqX3GZtOPhy4CRmjJ6Hx+XfdN1kMTL9xAQs0YVr2pJyPJW+jQaSdToLp92FTq9Db9Tzv1kDaXjrtYUa+wxSSlbNXc/scb9iS7fR6r4W3PnkbZitWvhO48LRsnQuMipdUaFESRaXRNYv+ifA2QPoDDr2bzlM3ea1CjV+mYoJTNjyEQsmLGbTH1uofGVFujzTgSq1Khdq3HP5etBkfhmzICdEtHfTAX779k8+XfVesUlfaFzaaA6/hOP1ejlx4CTR8VHElo3J1z1SSvZs3E9GSia1m15JVKw138/zuD0c3vkfMWWiSaxc5kLNLnLKVSnL3k0HAo573F4SKoQnbTImIZpuL3em28udwzLeuST/l8LPo+b5ZWU57S7+23OcP6Yso/0jt4T9mRoamsMvwSydsZrRfb7AaXfi9Xi5vs21vDapPzEJoXXhj+0/waAO75L8Xyo6RcHj9vDYez24p/+d533e7z/8zZhnvkJVVTxuL9e0qM2QaS/k+0UTSe59sRMbl2zxy6TRG3TUblyTSjVCr44O7TjKqjnrMJgMtLy3ebG91LYu34nBqA9Iw3VkOVk9b4Pm8DWKBC04XELZtX4vIx7+hLTkdBxZTtxODxsX/8ubXd4PeY+UkkEdhnF0z3EcmQ6y0m047S6+HjSFzX9vy/N521buZNRTn5OVZsOe4cDtcLNl2Xbe6Dwi3B8tKKqqsn31bjb9sSWnGUpe1G99Df0+eRRr9ua20WygXss6vDXj5ZD3fP36ZPo2eoWJg6cw4dVJ9LrqWRZPXhrOj5FvYstGIwncP1N0CmUqxReDRRqlgbDM8IUQXwMdgSQpZb0g5wUwGrgDsAGPSCk3hOPZlyo/fTQXl91/9udxedi1bi9Hdh+jylWVAu7Zu+kAyUdPIVV/R+KyO/nl0wVc16puyOdN/3AOrlw9bz1uL3s27g/5vHCx/9+DvH7ne2SmZSGEQKqSF77sw83335jnfR0eu5VbH2zF4R1HiUuMIfGy0NkzO9fuYcboeWc/o9sLwEePj6dJ++sjvoq5rnVdrLFWHJkOzs2bMJj0dHyqXURt0Sg9hGuG/w1wex7nOwBXZf95Ehgfpudeshzbd4JgGVR6o55TR1OC3pOekolOH5hKKCWcPk8D8JOHTxEsYUtn0HHiQBKnT6bl2LNyzjr6NHyZu8s+wgut32DL8h15jr1r/V6ebzWEOyw9uP+yJ5k+cjaqqgK+PYOX2/6Pk0dOYc9wYEu3Y890MPKxcRzacTTPcQGMJgM161fP09kD/DF1ecALFECn17F6XuTnHjqdjpGL36RSzYqYo0xYYy1YYyy89NXTVL/m8ojbo1E6CMsMX0r5txCieh6XdAa+kz6PsUoIES+EqCSl1PQCQtCgTT32btqP2+mfieJ2uqlxXdWg91zd9MqgmSsmi5Eb726a5/Ma3XYd+zYfDIgp2zMdDO40HCEgoWI8N3e7kVljF+C0+WbK/y
7dzqvthjL818HUu6lOwLgHtx/hxZvfzMlESTmWyrdvTSP56Cn6fvwo6xdtDion4XF7mD/hd/qM7JWn3flFSgkCckd
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"plt.scatter(X[:,0],X[:,1],c=y)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"(200, 2)"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 13,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"X.shape"
|
||
|
]
|
||
|
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 15,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Hyperparameters\n",
  "input_neurons = 2  # presumably one input unit per column of X (X.shape is (200, 2)) -- TODO confirm\n",
  "output_neurons = 2  # presumably one output unit per class -- TODO confirm against y\n",
  "samples = X.shape[0]  # number of rows in X\n",
  "learning_rate = 0.001  # not read by any cell in this notebook yet; presumably the gradient step size\n",
  "lambda_reg = 0.01  # coefficient of the squared-weight penalty computed in loss()\n"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Create the parameters in this cell: W1/b1/W2/b2 are defined nowhere else in the\n",
  "# notebook, so the dictionary below used to fail with NameError on a fresh kernel.\n",
  "hidden_neurons = 3  # hidden-layer width; arbitrary starting value, tune freely\n",
  "rng = np.random.RandomState(0)  # fixed seed -> reproducible initial weights\n",
  "# Scale by 1/sqrt(fan_in) so the tanh units do not start out saturated.\n",
  "W1 = rng.randn(input_neurons, hidden_neurons) / np.sqrt(input_neurons)\n",
  "b1 = np.zeros((1, hidden_neurons))\n",
  "W2 = rng.randn(hidden_neurons, output_neurons) / np.sqrt(hidden_neurons)\n",
  "b2 = np.zeros((1, output_neurons))\n",
  "model_dic = {'W1': W1, 'b1': b1,'W2': W2, 'b2': b2}"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 17,
 "metadata": {},
 "outputs": [],
 "source": [
  "def retrieve(model_dict):\n",
  "    \"\"\"Unpack the network parameters stored in ``model_dict``.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    model_dict : dict\n",
  "        Mapping with the keys 'W1', 'b1', 'W2' and 'b2'.\n",
  "\n",
  "    Returns\n",
  "    -------\n",
  "    tuple\n",
  "        ``(W1, b1, W2, b2)``, in that order.\n",
  "\n",
  "    Raises\n",
  "    ------\n",
  "    KeyError\n",
  "        If any of the four keys is missing.\n",
  "    \"\"\"\n",
  "    # Read from the argument rather than the notebook-level `model_dic`, so the\n",
  "    # function works for whatever parameter set the caller passes in.\n",
  "    return tuple(model_dict[key] for key in ('W1', 'b1', 'W2', 'b2'))"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 19,
 "metadata": {},
 "outputs": [],
 "source": [
  "def forward(x, model_dict):\n",
  "    \"\"\"Forward pass of the two-layer tanh network; returns softmax probabilities.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    x : numpy array\n",
  "        Input batch, multiplied on the right by W1. Must be 2-D because the\n",
  "        normalisation sums over axis 1 (presumably one row per sample -- TODO confirm).\n",
  "    model_dict : dict\n",
  "        Parameters with the keys 'W1', 'b1', 'W2' and 'b2' (see ``retrieve``).\n",
  "\n",
  "    Returns\n",
  "    -------\n",
  "    numpy array\n",
  "        Same number of rows as ``x``; every row is positive and sums to 1.\n",
  "    \"\"\"\n",
  "    W1, b1, W2, b2 = retrieve(model_dict)\n",
  "    z1 = x.dot(W1) + b1\n",
  "    a1 = np.tanh(z1)\n",
  "    z2 = a1.dot(W2) + b2\n",
  "    # NOTE(review): tanh on the output layer bounds the scores to [-1, 1], which caps\n",
  "    # how confident the softmax can become -- confirm this is intended.\n",
  "    a2 = np.tanh(z2)\n",
  "    exp_scores = np.exp(a2)\n",
  "    # NumPy's keyword is `axis`; `dim` is not accepted by np.sum and raised TypeError.\n",
  "    return exp_scores / np.sum(exp_scores, axis=1, keepdims=True)"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": 20,
 "metadata": {},
 "outputs": [],
 "source": [
  "def loss(softmax, y, model_dict=None):\n",
  "    \"\"\"Cross-entropy of ``softmax`` against labels ``y`` plus an L2 weight penalty.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    softmax : numpy array, one row per sample and one column per class\n",
  "        Predicted class probabilities; entries must be strictly positive.\n",
  "    y : integer numpy array, one entry per sample\n",
  "        Column index of the correct class for the matching row of ``softmax``.\n",
  "    model_dict : dict, optional\n",
  "        Parameters whose 'W1' and 'W2' entries are penalised. When omitted, the\n",
  "        notebook-level ``model_dic`` is used.\n",
  "\n",
  "    Returns\n",
  "    -------\n",
  "    float\n",
  "        (sum over samples of -log p[correct class]\n",
  "         + lambda_reg / 2 * (sum(W1**2) + sum(W2**2))) / number of samples.\n",
  "    \"\"\"\n",
  "    if model_dict is None:\n",
  "        # The previous body read a global named `model_dict` that the notebook never\n",
  "        # defines; fall back to the dictionary it actually creates.\n",
  "        model_dict = model_dic\n",
  "    W1, _, W2, _ = retrieve(model_dict)\n",
  "    softmax = np.asarray(softmax)\n",
  "    y = np.asarray(y)\n",
  "    n_samples = y.shape[0]\n",
  "    # Probability given to the true class of every sample. Fancy indexing replaces\n",
  "    # the per-row loop and the buffer that was hard-coded to 200 entries.\n",
  "    correct_probs = softmax[np.arange(n_samples), y]\n",
  "    # Sum over ALL samples; previously only the last sample's log-probability was used.\n",
  "    softmax_loss = np.sum(-np.log(correct_probs))\n",
  "    reg_loss = lambda_reg / 2 * (np.sum(np.square(W1)) + np.sum(np.square(W2)))\n",
  "    return float((softmax_loss + reg_loss) / n_samples)"
 ]
},
|
||
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Stand-in predictions, presumably meant as dummy input for loss() -- TODO confirm.\n",
  "# Each row has to be a probability distribution: the previous randn() draw contained\n",
  "# negative entries, for which -log() is NaN.\n",
  "raw_scores = np.random.RandomState(0).rand(samples, output_neurons)\n",
  "softmax = raw_scores / raw_scores.sum(axis=1, keepdims=True)"
 ]
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.7.6"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|