You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
410 lines
54 KiB
Plaintext
410 lines
54 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "ecf6f699-9799-49dd-b553-2bf87b763ea1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "08a7ac73-f032-43f6-985c-ef004b0c123d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"df = pd.read_csv(\"insurance_data.csv\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "24ebd375-3904-4c96-bb61-1810dfb24e48",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div>\n",
|
|
"<style scoped>\n",
|
|
" .dataframe tbody tr th:only-of-type {\n",
|
|
" vertical-align: middle;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe tbody tr th {\n",
|
|
" vertical-align: top;\n",
|
|
" }\n",
|
|
"\n",
|
|
" .dataframe thead th {\n",
|
|
" text-align: right;\n",
|
|
" }\n",
|
|
"</style>\n",
|
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|
" <thead>\n",
|
|
" <tr style=\"text-align: right;\">\n",
|
|
" <th></th>\n",
|
|
" <th>age</th>\n",
|
|
" <th>bought_insurance</th>\n",
|
|
" </tr>\n",
|
|
" </thead>\n",
|
|
" <tbody>\n",
|
|
" <tr>\n",
|
|
" <th>0</th>\n",
|
|
" <td>22</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>1</th>\n",
|
|
" <td>25</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>2</th>\n",
|
|
" <td>47</td>\n",
|
|
" <td>1</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>3</th>\n",
|
|
" <td>52</td>\n",
|
|
" <td>0</td>\n",
|
|
" </tr>\n",
|
|
" <tr>\n",
|
|
" <th>4</th>\n",
|
|
" <td>46</td>\n",
|
|
" <td>1</td>\n",
|
|
" </tr>\n",
|
|
" </tbody>\n",
|
|
"</table>\n",
|
|
"</div>"
|
|
],
|
|
"text/plain": [
|
|
" age bought_insurance\n",
|
|
"0 22 0\n",
|
|
"1 25 0\n",
|
|
"2 47 1\n",
|
|
"3 52 0\n",
|
|
"4 46 1"
|
|
]
|
|
},
|
|
"execution_count": 27,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"df.head()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "260592d2-b720-4168-a5bd-c3f916d7e508",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.collections.PathCollection at 0x22de0936700>"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"fig, ax = plt.subplots()\n",
"ax.scatter(df[\"age\"], df[\"bought_insurance\"], marker=\"+\", c=df[\"bought_insurance\"])\n",
"# Label the axes so the figure can be read without the surrounding cells;\n",
"# the trailing semicolon keeps the return value out of the cell output.\n",
"ax.set(xlabel=\"age\", ylabel=\"bought_insurance\", title=\"Insurance purchase by age\");"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "02a35555-40b0-458f-8c67-c061ba3ffc94",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def sigmoid(z):\n",
"    \"\"\"Logistic function 1 / (1 + exp(-z)), evaluated element-wise.\n",
"\n",
"    Computed as exp(-log(1 + exp(-z))) through np.logaddexp so that large\n",
"    negative inputs do not overflow inside exp (the direct form emits a\n",
"    RuntimeWarning once z drops below roughly -709).\n",
"\n",
"    Args:\n",
"        z: scalar, NumPy array or pandas Series of real numbers.\n",
"\n",
"    Returns:\n",
"        Value(s) in [0, 1] with the same shape as z.\n",
"    \"\"\"\n",
"    return np.exp(-np.logaddexp(0, -z))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"id": "b92661b6-b83c-4255-8aa2-ed5a70ed849b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Evenly spaced inputs used to evaluate and plot the sigmoid curve.\n",
"x = np.arange(start=-10, stop=10, step=0.001)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"id": "7bba53a6-5925-49e7-9509-1713638da989",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"y = sigmoid(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "ff24178a-2cb8-48c6-a480-cf9bd530dbb6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def derivative_sigmoid(z):\n",
"    \"\"\"Derivative of the logistic function with respect to z.\n",
"\n",
"    Uses the identity sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)).\n",
"\n",
"    Args:\n",
"        z: scalar or array of real numbers accepted by sigmoid.\n",
"\n",
"    Returns:\n",
"        The derivative evaluated element-wise at z.\n",
"    \"\"\"\n",
"    s = sigmoid(z)\n",
"    return s * (1 - s)\n",
"\n",
"\n",
"dy = derivative_sigmoid(x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 36,
|
|
"id": "8d966ed9-b435-4440-88cf-07d81d4d2d90",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<matplotlib.legend.Legend at 0x22de0ad5040>"
|
|
]
|
|
},
|
|
"execution_count": 36,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
},
|
|
{
|
|
"data": {
|
|
"image/png": "",
|
|
"text/plain": [
|
|
"<Figure size 640x480 with 1 Axes>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"fig, ax = plt.subplots()\n",
"# Labels are attached to each line so the legend cannot drift out of order.\n",
"ax.plot(x, y, label=\"sigmoid\")\n",
"ax.plot(x, dy, label=\"derivative_sigmoid\")\n",
"ax.set(xlabel=\"z\", title=\"Sigmoid and its derivative\")\n",
"ax.legend();"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "880b3325-7406-42fb-a68d-86ea1ae0fb12",
|
|
"metadata": {},
|
|
"source": [
|
|
"最小化:\n",
|
|
"$$loss = -\\frac{1}{n}\\sum_{i=1}^n [y_i \\cdot log(\\hat{y_i}) + (1-y_i) \\cdot log(1 - \\hat{y_i})]$$\n",
|
|
"$$\\hat{y_i} = sigmoid(z_i), \\quad z_i = w \\cdot x_i + b$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9860e2a9-c601-4373-b772-6cf89ea8bb9f",
|
|
"metadata": {},
|
|
"source": [
|
|
"根据链式求导法则:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "44920e6d-9749-4fff-b4fb-6d29350352ff",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$\\frac{\\partial loss}{\\partial w} = \\frac{\\partial loss}{\\partial \\hat{y_i}} \\cdot \\frac{\\partial \\hat{y_i}}{\\partial z_i} \\cdot \\frac{\\partial z_i}{\\partial w}$$\n",
|
|
"$$\\frac{\\partial loss}{\\partial b} = \\frac{\\partial loss}{\\partial \\hat{y_i}} \\cdot \\frac{\\partial \\hat{y_i}}{\\partial z_i} \\cdot \\frac{\\partial z_i}{\\partial b}$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "71f01557-4978-410c-b803-9b8935b25b7e",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$\\frac{\\partial loss}{\\partial \\hat{y_i}} = -\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}$$\n",
|
|
"$$\\frac{\\partial \\hat{y_i}}{\\partial z_i} = \\hat{y_i} \\cdot (1 - \\hat{y_i})$$\n",
|
|
"$$\\frac{\\partial z_i}{\\partial w} = x_i$$\n",
|
|
"$$\\frac{\\partial z_i}{\\partial b} = 1$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1b670a1f-1a58-431a-8f4f-7792cf6fdec4",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$\\frac{\\partial loss}{\\partial w} = \\left(-\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}\\right) \\cdot \\hat{y_i} \\cdot (1 - \\hat{y_i}) \\cdot x_i = (\\hat{y_i} - y_i) \\cdot x_i$$\n",
|
|
"$$\\frac{\\partial loss}{\\partial b} = \\left(-\\frac{y_i}{\\hat{y_i}} + \\frac{1 - y_i}{1 - \\hat{y_i}}\\right) \\cdot \\hat{y_i} \\cdot (1 - \\hat{y_i}) \\cdot 1 = (\\hat{y_i} - y_i) \\cdot 1 = \\hat{y_i} - y_i$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 72,
|
|
"id": "9a608877-c041-4441-9982-ef1195633a06",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def logistic_regression(x, y, iterations, learning_rate, log_every=1000):\n",
"    \"\"\"Fit a one-feature logistic regression by batch gradient descent.\n",
"\n",
"    The model is y_hat = sigmoid(w * x + b). Each step moves w and b against\n",
"    the gradient of the mean binary cross-entropy, which is\n",
"    mean((y_hat - y) * x) for w and mean(y_hat - y) for b.\n",
"\n",
"    Args:\n",
"        x: 1-D feature values (list, NumPy array or pandas Series).\n",
"        y: labels paired position-by-position with x; the cross-entropy\n",
"            loss assumes values of 0 or 1.\n",
"        iterations: number of gradient steps to run.\n",
"        learning_rate: step size applied to both gradients.\n",
"        log_every: print the loss every this many iterations. Defaults to\n",
"            1000; pass 0 or None to disable printing.\n",
"\n",
"    Returns:\n",
"        Tuple (w, b) holding the fitted weight and bias.\n",
"    \"\"\"\n",
"    x = np.asarray(x, dtype=float)\n",
"    y = np.asarray(y, dtype=float)\n",
"    eps = 1e-15  # keeps log() finite when a prediction saturates at 0 or 1\n",
"    w, b = 0.0, 0.0\n",
"    for i in range(1, iterations + 1):\n",
"        y_hat = sigmoid(w * x + b)\n",
"        # The loss is only reported, never used for the update, so it is\n",
"        # evaluated on logging iterations only.\n",
"        if log_every and i % log_every == 0:\n",
"            p = np.clip(y_hat, eps, 1 - eps)\n",
"            loss = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
"            print(f\"Loss:{round(loss,2)}, Iteration: {i}\")\n",
"        dw = np.mean((y_hat - y) * x)\n",
"        db = np.mean(y_hat - y)\n",
"        w -= learning_rate * dw\n",
"        b -= learning_rate * db\n",
"    return w, b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 74,
|
|
"id": "5099121d-ed49-44f0-b706-5512b9cc1955",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loss:1.62, Iteration: 1000\n",
|
|
"Loss:0.75, Iteration: 2000\n",
|
|
"Loss:0.58, Iteration: 3000\n",
|
|
"Loss:0.49, Iteration: 4000\n",
|
|
"Loss:0.45, Iteration: 5000\n",
|
|
"Loss:0.43, Iteration: 6000\n",
|
|
"Loss:0.42, Iteration: 7000\n",
|
|
"Loss:0.41, Iteration: 8000\n",
|
|
"Loss:0.4, Iteration: 9000\n",
|
|
"Loss:0.4, Iteration: 10000\n",
|
|
"Loss:0.39, Iteration: 11000\n",
|
|
"Loss:0.39, Iteration: 12000\n",
|
|
"Loss:0.39, Iteration: 13000\n",
|
|
"Loss:0.38, Iteration: 14000\n",
|
|
"Loss:0.38, Iteration: 15000\n",
|
|
"Loss:0.38, Iteration: 16000\n",
|
|
"Loss:0.38, Iteration: 17000\n",
|
|
"Loss:0.38, Iteration: 18000\n",
|
|
"Loss:0.38, Iteration: 19000\n",
|
|
"Loss:0.38, Iteration: 20000\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Train the hand-written model on age -> bought_insurance.\n",
"w, b = logistic_regression(\n",
"    x=df[\"age\"],\n",
"    y=df[\"bought_insurance\"],\n",
"    iterations=20000,\n",
"    learning_rate=0.01,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"id": "cdac8f71-8ea7-4877-b84e-8beba356d0c9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.linear_model import LogisticRegression"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "15cfeb6b-32ba-4ca1-be89-ea3b1438ecdc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"lr = LogisticRegression()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"id": "82f58946-4773-4516-afd0-f3e92b443435",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label 
{background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LogisticRegression()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LogisticRegression</label><div class=\"sk-toggleable__content\"><pre>LogisticRegression()</pre></div></div></div></div></div>"
|
|
],
|
|
"text/plain": [
|
|
"LogisticRegression()"
|
|
]
|
|
},
|
|
"execution_count": 48,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fit scikit-learn's estimator on the same single feature for comparison.\n",
"lr.fit(X=df[[\"age\"]], y=df[\"bought_insurance\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 50,
|
|
"id": "152fcad5-e55c-4524-a0a2-c5ee8cf5240a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([[0.1354656]]), array([-5.26279696]))"
|
|
]
|
|
},
|
|
"execution_count": 50,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"lr.coef_, lr.intercept_"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4a91cddc-de82-4957-8683-9f56c9485705",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|