You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

410 lines
54 KiB
Plaintext

6 months ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 12,
"id": "ecf6f699-9799-49dd-b553-2bf87b763ea1",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "08a7ac73-f032-43f6-985c-ef004b0c123d",
"metadata": {},
"outputs": [],
"source": [
"# Each row holds `age` and the binary label `bought_insurance` (0/1), as shown by df.head() below\n",
"df = pd.read_csv(filepath_or_buffer=\"insurance_data.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "24ebd375-3904-4c96-bb61-1810dfb24e48",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>bought_insurance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>22</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>25</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>47</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>52</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>46</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" age bought_insurance\n",
"0 22 0\n",
"1 25 0\n",
"2 47 1\n",
"3 52 0\n",
"4 46 1"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Peek at the first five rows of the dataset\n",
"df.head(n=5)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "260592d2-b720-4168-a5bd-c3f916d7e508",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.collections.PathCollection at 0x22de0936700>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAeL0lEQVR4nO3df3TV9X348VcC5AYqCbRIwo8odtWiU8CC5mTU0zlTmfWwsh/9cqwrlLXu6GgHZD+EVUg7W8N0OuqBmknr7PluHahntnZaHEsLO66sTBirropacXDUBDgrCY2YuNzP9w+/Xk0JmosJbwKPxzn3nPC578+970/eCfd5PvdHSrIsywIAIJHS1BMAAE5vYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJIannoC/ZHP5+Oll16K0aNHR0lJSerpAAD9kGVZHD58OCZOnBilpcc+/zEkYuSll16Kmpqa1NMAAI7Dvn37YvLkyce8fkjEyOjRoyPi9YOpqKhIPBsAoD86Ojqipqam8Dh+LEMiRt54aqaiokKMAMAQ804vsfACVgAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAIKmiY+Rf/uVfYu7cuTFx4sQoKSmJb3/72++4z5YtW+JDH/pQ5HK5+MAHPhD33nvvcUyVEynLvxL51vMi33peZPlXUk+HU9Sp9HPW32MZ6HEp55jSYBzLQB/3UFjDk2Wti46Rzs7OmD59eqxbt65f4/fs2RNXX311XH755bFr165YunRpfPazn41HH3206MkCAKeeov82zVVXXRVXXXVVv8c3NzfHOeecE7fffntERJx//vnx2GOPxV/91V/FnDlzir17BlmhjLMjb9l4JLL861+WlI468ZPilHMq/Zz191gGatxbx6aaY0qD8X0c6OMejO/jif45O9FrPeh/KG/btm1RX1/fa9ucOXNi6dKlx9ynq6srurq6Cv/u6OgYrOnxC7L9M47edqCu8HVJ9TMncDacqk6ln7P+HstAjXvr2FRzTGkwvo8DfdyD8X080T9nJ3qtB/0FrK2trVFVVdVrW1VVVXR0dMSRI0f63KepqSkqKysLl5qamsGeJgCQyKCfGTkeK1asiIaGhsK/Ozo6BMkJUjJ+1+tfZEcKlVxy5raIkpHpJsUp51T6OevvsQz0uJRzTGkwjmWgj3sorOHJttaDHiPV1dXR1tbWa1tbW1tUVFTEyJF9H3Qul4tcLjfYU6MPbz6f+NaNI0+K54o5dZxKP2f9PZaBHpdyjikNxrEM9HEPhTU82dZ60J+mqauri5aWll7bNm/eHHV1dcfYAwA4nRQdIz//+c9j165dsWvXroh4/a27u3btir1790bE60+xLFiwoDD++uuvj+effz7+9E//NJ5++un42te+Fvfdd18sW7ZsYI4AABjSSrIsy4rZYcuWLXH55ZcftX3hwoVx7733xqc//el44YUXYsuWLb32WbZsWfzkJz+JyZMnx8qVK+PTn/50v++zo6MjKisro729PSoqKoqZLgCQSH8fv4uOkRTECAAMPf19/Pa3aQCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKSOK0bWrVsXU6ZMifLy8qitrY3t27e/7fg1a9bEBz/4wRg5cmTU1NTEsmXL4t
VXXz2uCQMAp5aiY2Tjxo3R0NAQjY2NsXPnzpg+fXrMmTMn9u/f3+f4b33rW7F8+fJobGyMp556Kr7xjW/Exo0b48/+7M/e9eQBgKGv6Bi544474rrrrotFixbFBRdcEM3NzTFq1Ki45557+hz/wx/+MGbPnh2f/OQnY8qUKXHllVfGNddc845nUwCA00NRMdLd3R07duyI+vr6N2+gtDTq6+tj27Ztfe7zK7/yK7Fjx45CfDz//PPxyCOPxMc+9rFj3k9XV1d0dHT0ugAAp6bhxQw+ePBg9PT0RFVVVa/tVVVV8fTTT/e5zyc/+ck4ePBgfPjDH44sy+J///d/4/rrr3/bp2mampriS1/6UjFTAwCGqEF/N82WLVvilltuia997Wuxc+fO+Id/+Id4+OGH4+abbz7mPitWrIj29vbCZd++fYM9TQAgkaLOjIwbNy6GDRsWbW1tvba3tbVFdXV1n/usXLkyPvWpT8VnP/vZiIi46KKLorOzM37/938/vvCFL0Rp6dE9lMvlIpfLFTM1AGCIKurMSFlZWcycOTNaWloK2/L5fLS0tERdXV2f+7zyyitHBcewYcMiIiLLsmLnCwCcYoo6MxIR0dDQEAsXLoxZs2bFpZdeGmvWrInOzs5YtGhRREQsWLAgJk2aFE1NTRERMXfu3Ljjjjvi4osvjtra2njuuedi5cqVMXfu3EKUAACnr6JjZP78+XHgwIFYtWpVtLa2xowZM2LTpk2FF7Xu3bu315mQm266KUpKSuKmm26KF198Mc4888yYO3dufOUrXxm4owAAhqySbAg8V9LR0RGVlZXR3t4eFRUVqacDAPRDfx+//W0aACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEkdV4ysW7cupkyZEuXl5VFbWxvbt29/2/GHDh2KxYsXx4QJEyKXy8V5550XjzzyyHFNGAA4tQwvdoeNGzdGQ0NDNDc3R21tbaxZsybmzJkTu3fvjvHjxx81vru7Oz760Y/G+PHj44EHHohJkybFf//3f8eYMWMGYv4AwBBXkmVZVswOtbW1cckll8TatWsjIiKfz0dNTU18/vOfj+XLlx81vrm5OW677bZ4+umnY8SIEcc1yY6OjqisrIz29vaoqKg4rtsAAE6s/j5+F/U0TXd3d+zYsSPq6+vfvIHS0qivr49t27b1uc9DDz0UdXV1sXjx4qiqqooLL7wwbrnllujp6Tnm/XR1dUVHR0evCwBwaioqRg4ePBg9PT1RVVXVa3tVVVW0trb2uc/zzz8fDzzwQPT09MQjjzwSK1eujNtvvz2+/OUvH/N+mpqaorKysnCpqakpZpoAwBAy6O+myefzMX78+Lj77rtj5syZMX/+/PjCF74Qzc3Nx9xnxYoV0d7eXrjs27dvsKcJACRS1AtYx40bF8OGDYu2trZe29va2qK6urrPfSZMmBAjRoyIYcOGFbadf/750draGt3d3VFWVnbUPrlcLnK5XDFTAwCGqKLOjJSVlcXMmTOjpaWlsC2fz0dLS0vU1dX1uc/s2bPjueeei3w+X9j2zDPPxIQJE/oMEQDg9FL00zQNDQ2xfv36+OY3vxlPPfVU3HDDDdHZ2RmLFi2KiIgFCxbEihUrCuNvuOGG+J//+Z9YsmRJPPPMM/Hwww/HLbfcEosXLx64owAAhqyiP2dk/vz5ceDAgV
i1alW0trbGjBkzYtOmTYUXte7duzdKS99snJqamnj00Udj2bJlMW3atJg0aVIsWbIkbrzxxoE7CgBgyCr6c0ZS8Dk
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Raw data: each point is one row, coloured by its 0/1 label\n",
"plt.scatter(df.age, df.bought_insurance, marker='+', c=df.bought_insurance)\n",
"plt.xlabel(\"age\")\n",
"plt.ylabel(\"bought_insurance\")\n",
"plt.title(\"bought_insurance vs. age\");"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "02a35555-40b0-458f-8c67-c061ba3ffc94",
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(z):\n",
"    \"\"\"Logistic function 1 / (1 + exp(-z)), applied element-wise to NumPy arrays or pandas Series.\"\"\"\n",
"    denominator = 1 + np.exp(-z)\n",
"    return 1 / denominator"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "b92661b6-b83c-4255-8aa2-ed5a70ed849b",
"metadata": {},
"outputs": [],
"source": [
"# Dense grid over [-10, 10) with step 0.001, used to plot the sigmoid and its derivative\n",
"x = np.arange(start=-10, stop=10, step=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "7bba53a6-5925-49e7-9509-1713638da989",
"metadata": {},
"outputs": [],
"source": [
"# Sigmoid evaluated on the plotting grid\n",
"y = sigmoid(z=x)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "ff24178a-2cb8-48c6-a480-cf9bd530dbb6",
"metadata": {},
"outputs": [],
"source": [
"# derivative_sigmoid was never defined in this notebook (NameError on a fresh kernel).\n",
"# Use the closed form sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)), reusing y = sigmoid(x) from the cell above.\n",
"dy = y * (1 - y)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "8d966ed9-b435-4440-88cf-07d81d4d2d90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x22de0ad5040>"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAABXPElEQVR4nO3dd3wUdf7H8dduyiYhJCEJpEAgdFF6iwFBlJwIyoGcisoJInbOU1FPUQHRE6yAP8tx54no2VAPEIUDEVGkSEfpHUJJQk+DtN35/bFkIZKEJCSZ3c37+XjsI7sz39n9TDbJvjPf73zHYhiGgYiIiIhJrGYXICIiIjWbwoiIiIiYSmFERERETKUwIiIiIqZSGBERERFTKYyIiIiIqRRGRERExFQKIyIiImIqX7MLKAuHw8Hhw4epXbs2FovF7HJERESkDAzDIDMzk9jYWKzWko9/eEQYOXz4MHFxcWaXISIiIhVw4MABGjRoUOJ6jwgjtWvXBpw7ExISYnI1IiIiUhYZGRnExcW5PsdL4hFhpLBrJiQkRGFERETEw1xsiIUGsIqIiIipFEZERETEVAojIiIiYiqPGDNSFna7nfz8fLPLELlkfn5++Pj4mF2GiEi18YowkpWVxcGDBzEMw+xSRC6ZxWKhQYMGBAcHm12KiEi18PgwYrfbOXjwIEFBQdStW1eToolHMwyDo0ePcvDgQZo3b64jJCJSI3h8GMnPz8cwDOrWrUtgYKDZ5Yhcsrp167Jv3z7y8/MVRkSkRvCaAaw6IiLeQj/LIlLTeE0YEREREc9U7jCyZMkS+vfvT2xsLBaLhdmzZ190mx9//JGOHTtis9lo1qwZ06dPr0CpNcddd93FwIEDzS4DgPj4eKZMmVJqm7L+HIiIiBSn3GNGsrOzadeuHXfffTeDBg26aPu9e/dyww038MADD/DJJ5+waNEi7rnnHmJiYujTp0+FivZ2b775ptucGbR69Wpq1apldhkiIuLFyh1G+vbtS9++fcvcfurUqTRu3Jg33ngDgFatWrF06VImT56sMFKC0NBQs0twqVu3rtkliIiIl6vys2lWrFhBUlJSkWV9+vTh0UcfLXGb3NxccnNzXY8zMjKqqjxTffXVV4wfP55du3YRFBREhw4d+Prrrxk5ciSnTp1ydX1kZmbywAMPMHv2bEJCQvjb3/7G119/Tfv27V1dKPHx8dxzzz3s2LGDmTNnEhERwVtvvUViYiL33HMPixYtokmTJkybNo3OnTu7avjvf//L2LFj2bVrFzExMTz88MM8/vjjrvXx8fE8+uijrvdr586djBgxglWrVtGkSRPefPPN6vp2iYhcwOEwyLM7yLM7sNsNChwGdodBvt2B3eF8XOBwUGA3zj4+//65xw7DwGHg+moYBsZ5jx2GgeFaV/TxxbZxLgMD532AIse+zy40ij7EOLvk99ucf+C8sA0XtCl9W+O8CgqXjbiqMXHhQeX59leaKg8jqampREVFFVkWFRVFRkYGZ86cKfZ03IkTJzJ+/PgKvZ5hGJzJt1do20sV6OdT5jMhUlJSuP3223n11Ve56aabyMzM5Oeffy62e2bUqFEsW7aMOXPmEBUVxdixY1m3bh3t27cv0m7y5MlMmDCBMWPGMHnyZO688066devG3XffzWuvvcZTTz3F0KFD2bx5MxaLhbVr13Lrrbfy/PPPM3jwYJYvX85DDz1EREQEd9111wV1OBwOBg0aRFRUFCtXriQ9Pb3UUCkiAs6/y5m5BZzKzicjJ5/MnAKycwvIzisgK9d5PyvX7lyWW0B2np2cfDu5BQ5yC78WuW8nN9+5LM/uMHv3vMaA9rHeG0YqYvTo0YwaNcr1OCMjg7i4uDJteybfzuVjF1RVaaXa8kIfgvzL9i1NSUmhoKCAQYMG0ahRIwDatGlzQbvMzEw+/PBDPv30U3r37g3ABx98QGxs7AVt+/Xrx/333w/A2LFj+cc//kGXLl245Z
ZbAHjqqadITEwkLS2N6OhoJk2aRO/evRkzZgwALVq0YMuWLbz22mvFhpHvv/+ebdu2sWDBAtfrT5gwoVzddiLiHQrsDo5m5ZKankNaRi5HMnNITc/hWFYuJ0/nc+p0nuvrqdP5FDiqbxycr9WCr48FX6sVH6sFPx8LPlbnY1/X/XOPfa0WrBYLVqsFqwUsWLBawWqxYLEULiv62Hm/9K/nPxc4HwMU/s9qwfK7x4Xri/5Te/H257e1FFlHGbe1YCEqJKAc3+XKVeVhJDo6mrS0tCLL0tLSCAkJKXGSMpvNhs1mq+rSTNWuXTt69+5NmzZt6NOnD9dddx0333wzderUKdJuz5495Ofn07VrV9ey0NBQWrZsecFztm3b1nW/8GjU+QGncNmRI0eIjo5m69atDBgwoMhzdO/enSlTpmC32y+YcGvr1q3ExcUVCUKJiYnl3XUR8QCGYXA0M5f9J06z//hpko9nu+4fOnWGY1m5lHecfaCfD6GBftSy+RBs86XW2Zvzvo/zvr8vQTZfAv18sPlasflZsfmeve9rJcDP54Jl/r5W/Hys+FqdQUNz9XieKg8jiYmJzJs3r8iyhQsXVtmHWKCfD1teMGdgbKBf2WfL9PHxYeHChSxfvpzvvvuOt956i2effZaVK1dW+PX9/Pxc9wt/GYtb5nDosKaInJN+Op9tqRlsT8tkW2om21Mz2ZGaSWZuQanb+Vot1KttIyo0gKjaAUSHBhAZ7E9YkD91gvypU8vP+TXIn7AgPwLK8TdSapZyh5GsrCx27drlerx37142bNhAeHg4DRs2ZPTo0Rw6dIiPPvoIgAceeIC3336bv/3tb9x999388MMPfPHFF8ydO7fy9uI8FoulzF0lZrNYLHTv3p3u3bszduxYGjVqxKxZs4q0adKkCX5+fqxevZqGDRsCkJ6ezo4dO+jZs+clvX6rVq1YtmxZkWXLli2jRYsWxU5D3qpVKw4cOEBKSgoxMTEA/PLLL5dUg4hUrwK7g22pmaw/cIr1ySfZkHyKPceyi21rtUBsWCCNIoJoGF6LRhFBNAoPIi48iKiQACJq+WO16iiEXLpyf2qvWbOGa665xvW4cGzHsGHDmD59OikpKSQnJ7vWN27cmLlz5/LYY4/x5ptv0qBBA/7973/X+NN6V65cyaJFi7juuuuoV68eK1eu5OjRo7Rq1YrffvvN1a527doMGzaMJ598kvDwcOrVq8e4ceOwWq2XfCjy8ccfp0uXLrz44osMHjyYFStW8Pbbb/Puu+8W2z4pKYkWLVowbNgwXnvtNTIyMnj22WcvqQYRqVqGYbDzSBZLdx5j2a5j/LLnONl5Fw7yb1AnkJZRtWkZfe7WJDIYf19N1C1Vr9xhpFevXqVOyFXc7Kq9evVi/fr15X0prxYSEsKSJUuYMmUKGRkZNGrUiDfeeIO+ffsyY8aMIm0nTZrEAw88wI033ug6tffAgQMEBFzaYKOOHTvyxRdfMHbsWF588UViYmJ44YUXih28CmC1Wpk1axYjRoyga9euxMfH83//939cf/31l1SHiFSuvAIHy3cfY8HmNBZtTeNIZm6R9SEBvrRvWIf2cWF0aBhG+wZh1Knlb1K1ImAx3GWqz1JkZGQQGhpKeno6ISEhRdbl5OSwd+9eGjdufMkfzp4iOzub+vXr88YbbzBixAizy5FKVhN/puXSFdgdLNl5lDkbDrNo2xEyc86N97D5WunaOJyrmkVyVfNIWkWHqHtFqkVpn9/n84zBFTXc+vXr2bZtG127diU9PZ0XXngB4IIzYUSk5tl1JJMv1xxk5vpDHD3vCEjd2jauuzyKPldE07VxuAaPiltTGPEQr7/+Otu3b8ff359OnTrx888/ExkZaXZZImICh8Pgh21HeH/pXlbsOe5aHl7LnwHtY7mxbQwd4uro6Id4DIURD9ChQwfWrl1rdhkiYrLcAjtfrDnItKV72Xv2DBgfq4VrWtbjls4NuKZlPQ04FY+kMC
Ii4ubyChx8ufYAb/+wi5T0HABqB/hyR9eGDO0WT/2w4ieQFPEUCiMiIm7KMAy++S2FV+dv4+DJMwBEhwTwwNVNuKV
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(x, y, label=\"sigmoid\")\n",
"plt.plot(x, dy, label=\"derivative_sigmoid\")\n",
"plt.legend()"
]
},
{
"cell_type": "markdown",
"id": "880b3325-7406-42fb-a68d-86ea1ae0fb12",
"metadata": {},
"source": [
"最小化交叉熵损失:\n",
"$$loss = -\\frac{1}{n}\\sum_{i=1}^n [y_i \\cdot \\log(\\hat{y_i}) + (1-y_i) \\cdot \\log(1 - \\hat{y_i})]$$\n",
"$$\\hat{y_i} = sigmoid(z_i), \\quad z_i = w \\cdot x_i + b$$"
]
},
{
"cell_type": "markdown",
"id": "9860e2a9-c601-4373-b772-6cf89ea8bb9f",
"metadata": {},
"source": [
"根据链式求导法则:"
]
},
{
"cell_type": "markdown",
"id": "44920e6d-9749-4fff-b4fb-6d29350352ff",
"metadata": {},
"source": [
"$$\\frac{\\partial loss}{\\partial w} = \\frac{\\partial loss}{\\partial \\hat{y_i}}\\cdot \\frac{\\partial \\hat{y_i}}{\\partial z_i} \\cdot\\frac{\\partial z_i}{\\partial w}$$\n",
"$$\\frac{\\partial loss}{\\partial b} = \\frac{\\partial loss}{\\partial \\hat{y_i}}\\cdot \\frac{\\partial \\hat{y_i}}{\\partial z_i}\\cdot \\frac{\\partial z_i}{\\partial b}$$"
]
},
{
"cell_type": "markdown",
"id": "71f01557-4978-410c-b803-9b8935b25b7e",
"metadata": {},
"source": [
"$$\\frac{\\partial loss}{\\partial \\hat{y_i}} = -\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}$$\n",
"$$\\frac{\\partial \\hat{y_i}}{\\partial z_i} = \\hat{y_i}\\cdot(1 - \\hat{y_i})$$\n",
"$$\\frac{\\partial z_i}{\\partial w} = x_i$$\n",
"$$\\frac{\\partial z_i}{\\partial b} = 1$$"
]
},
{
"cell_type": "markdown",
"id": "1b670a1f-1a58-431a-8f4f-7792cf6fdec4",
"metadata": {},
"source": [
"$$\\frac{\\partial loss}{\\partial w} = \\left(-\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}\\right) \\cdot \\hat{y_i} \\cdot (1 - \\hat{y_i}) \\cdot x_i = (\\hat{y_i} - y_i) \\cdot x_i$$\n",
"$$\\frac{\\partial loss}{\\partial b} = \\left(-\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}\\right) \\cdot \\hat{y_i} \\cdot (1 - \\hat{y_i}) \\cdot 1 = \\hat{y_i} - y_i$$\n",
"\n",
"以上为单个样本 $i$ 的梯度；损失对 $n$ 个样本取平均，因此代码中 `dw = np.mean((y_hat - y) * x)`，`db = np.mean(y_hat - y)`。"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "9a608877-c041-4441-9982-ef1195633a06",
"metadata": {},
"outputs": [],
"source": [
"def logistic_regression(x, y, iterations, learning_rate, log_every=1000, eps=1e-12):\n",
"    \"\"\"Fit y ~ sigmoid(w * x + b) with one feature by batch gradient descent.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : array-like, the single input feature (e.g. ``df.age``).\n",
"    y : array-like of 0/1 labels, same length as ``x``.\n",
"    iterations : number of gradient-descent steps.\n",
"    learning_rate : step size applied to both gradients.\n",
"    log_every : print the mean cross-entropy loss every ``log_every`` steps;\n",
"        0 or None disables logging.\n",
"    eps : predictions are clipped into [eps, 1 - eps] before taking logs so the\n",
"        printed loss stays finite when the sigmoid saturates to 0 or 1.\n",
"\n",
"    Returns\n",
"    -------\n",
"    (w, b) : the learned weight and bias.\n",
"    \"\"\"\n",
"    w, b = 0, 0\n",
"    for i in range(1, iterations + 1):\n",
"        y_hat = sigmoid(w * x + b)\n",
"        if log_every and i % log_every == 0:\n",
"            # Loss is only reported, never used for the update, so compute it only when logging.\n",
"            # Clip so log(0) cannot turn the reported loss into inf/nan.\n",
"            y_hat_safe = np.clip(y_hat, eps, 1 - eps)\n",
"            loss = np.mean(-(y * np.log(y_hat_safe) + (1 - y) * np.log(1 - y_hat_safe)))\n",
"            print(f\"Loss:{round(loss,2)}, Iteration: {i}\")\n",
"        # Gradients of the mean cross-entropy: d/dw = mean((y_hat - y) * x), d/db = mean(y_hat - y)\n",
"        dw = np.mean((y_hat - y) * x)\n",
"        db = np.mean(y_hat - y)\n",
"        w -= learning_rate * dw\n",
"        b -= learning_rate * db\n",
"    return w, b"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "5099121d-ed49-44f0-b706-5512b9cc1955",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss:1.62, Iteration: 1000\n",
"Loss:0.75, Iteration: 2000\n",
"Loss:0.58, Iteration: 3000\n",
"Loss:0.49, Iteration: 4000\n",
"Loss:0.45, Iteration: 5000\n",
"Loss:0.43, Iteration: 6000\n",
"Loss:0.42, Iteration: 7000\n",
"Loss:0.41, Iteration: 8000\n",
"Loss:0.4, Iteration: 9000\n",
"Loss:0.4, Iteration: 10000\n",
"Loss:0.39, Iteration: 11000\n",
"Loss:0.39, Iteration: 12000\n",
"Loss:0.39, Iteration: 13000\n",
"Loss:0.38, Iteration: 14000\n",
"Loss:0.38, Iteration: 15000\n",
"Loss:0.38, Iteration: 16000\n",
"Loss:0.38, Iteration: 17000\n",
"Loss:0.38, Iteration: 18000\n",
"Loss:0.38, Iteration: 19000\n",
"Loss:0.38, Iteration: 20000\n"
]
}
],
"source": [
"w, b = logistic_regression(\n",
"    df.age, df.bought_insurance, iterations=20000, learning_rate=0.01\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "cdac8f71-8ea7-4877-b84e-8beba356d0c9",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "15cfeb6b-32ba-4ca1-be89-ea3b1438ecdc",
"metadata": {},
"outputs": [],
"source": [
"# sklearn's LogisticRegression applies L2 regularization by default (C=1.0), so its\n",
"# coefficients need not match the unregularized gradient-descent fit above.\n",
"lr = LogisticRegression(C=1.0)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "82f58946-4773-4516-afd0-f3e92b443435",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-r
],
"text/plain": [
"LogisticRegression()"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sklearn expects a 2-D feature matrix, hence df[[\"age\"]] (a DataFrame) rather than df.age\n",
"lr.fit(X=df[[\"age\"]], y=df.bought_insurance)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "152fcad5-e55c-4524-a0a2-c5ee8cf5240a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0.1354656]]), array([-5.26279696]))"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sklearn's fitted weight and bias, for comparison with w, b from gradient descent\n",
"(lr.coef_, lr.intercept_)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a91cddc-de82-4957-8683-9f56c9485705",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}