You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1081 lines
36 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "5a1fa66c-e538-4215-9ad6-ceedecef6645",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_iris"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "17a95bdb-6bb2-4c58-bacc-2ffa368acfdf",
"metadata": {},
"outputs": [],
"source": [
"iris = load_iris()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fd41f149-4238-41a2-81b6-ce69d2a7e416",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4b907051-63e5-4c3f-a0e5-b951d5060b69",
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(iris.data, columns=iris.feature_names)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "dc2c3117-18f3-4ab8-b00d-046afc9bb204",
"metadata": {},
"outputs": [],
"source": [
"df[\"flower\"] = iris.target"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c51f710f-5fdb-4ef8-8108-f9030698c6c7",
"metadata": {},
"outputs": [],
"source": [
"df[\"flower\"] = df[\"flower\"].apply(lambda x: iris.target_names[x])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "259932d6-19fd-4113-917b-55c3d4c0643c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal length (cm)</th>\n",
" <th>sepal width (cm)</th>\n",
" <th>petal length (cm)</th>\n",
" <th>petal width (cm)</th>\n",
" <th>flower</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>setosa</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>setosa</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) \\\n",
"0 5.1 3.5 1.4 0.2 \n",
"1 4.9 3.0 1.4 0.2 \n",
"2 4.7 3.2 1.3 0.2 \n",
"3 4.6 3.1 1.5 0.2 \n",
"4 5.0 3.6 1.4 0.2 \n",
"\n",
" flower \n",
"0 setosa \n",
"1 setosa \n",
"2 setosa \n",
"3 setosa \n",
"4 setosa "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d8e7edb4-7ae3-4baf-b196-2e6082c8dc5d",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "30bc7e98-6509-49b4-a88d-652437b308b5",
"metadata": {},
"outputs": [],
"source": [
"x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "2699689a-4f13-4186-92f8-ce325bac4e3a",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "887940a9-5f38-408d-9b6c-771e51bef55e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = SVC(kernel='rbf', C=20, gamma='auto')\n",
"model.fit(x_train, y_train)\n",
"model.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "70fc50cc-f802-43b7-8d99-267569e9f920",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e5e5ef79-ed19-49d9-8b85-945c30a305f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1. , 1. , 0.9 , 0.93333333, 1. ])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(SVC(kernel='linear', C=20, gamma='auto'), iris.data, iris.target, cv=5)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "bc36f23b-8e1f-4d90-a1da-6541035bde30",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.96666667, 1. , 0.96666667, 0.96666667, 1. ])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(SVC(kernel='rbf', C=10, gamma='auto'), iris.data, iris.target, cv=5)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "2d2e4747-7ca2-4195-9ea3-e4b473cb2ca5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.96666667, 1. , 0.9 , 0.96666667, 1. ])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(SVC(kernel='rbf', C=20, gamma='auto'), iris.data, iris.target, cv=5)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b3b7fb8d-f0df-48e3-b15d-0df4fbc13dcb",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"kernels = [\"rbf\", \"linear\"]\n",
"C = [1, 10, 20]\n",
"avg_score = {}\n",
"for k in kernels:\n",
" for c in C:\n",
" cv_scores = cross_val_score(SVC(kernel=k, C=c, gamma='auto'), iris.data, iris.target, cv=5)\n",
" avg_score[k+\"_\"+str(c)] = np.mean(cv_scores)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "897ec1e1-e14d-44c1-9a56-0181575d69d2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'rbf_1': 0.9800000000000001,\n",
" 'rbf_10': 0.9800000000000001,\n",
" 'rbf_20': 0.9666666666666668,\n",
" 'linear_1': 0.9800000000000001,\n",
" 'linear_10': 0.9733333333333334,\n",
" 'linear_20': 0.9666666666666666}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"avg_score"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "76c2fcbc-ea50-45e2-8994-8ec375807da2",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "740bea6d-442b-4c02-8037-986a50eff030",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>GridSearchCV(cv=5, estimator=SVC(gamma=&#x27;auto&#x27;),\n",
" param_grid={&#x27;C&#x27;: [1, 10, 20], &#x27;kernel&#x27;: [&#x27;rbf&#x27;, &#x27;linear&#x27;]})</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item sk-dashed-wrapped\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" ><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">GridSearchCV</label><div class=\"sk-toggleable__content\"><pre>GridSearchCV(cv=5, estimator=SVC(gamma=&#x27;auto&#x27;),\n",
" param_grid={&#x27;C&#x27;: [1, 10, 20], &#x27;kernel&#x27;: [&#x27;rbf&#x27;, &#x27;linear&#x27;]})</pre></div></div></div><div class=\"sk-parallel\"><div class=\"sk-parallel-item\"><div class=\"sk-item\"><div class=\"sk-label-container\"><div class=\"sk-label sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-2\" type=\"checkbox\" ><label for=\"sk-estimator-id-2\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">estimator: SVC</label><div class=\"sk-toggleable__content\"><pre>SVC(gamma=&#x27;auto&#x27;)</pre></div></div></div><div class=\"sk-serial\"><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-3\" type=\"checkbox\" ><label for=\"sk-estimator-id-3\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">SVC</label><div class=\"sk-toggleable__content\"><pre>SVC(gamma=&#x27;auto&#x27;)</pre></div></div></div></div></div></div></div></div></div></div>"
],
"text/plain": [
"GridSearchCV(cv=5, estimator=SVC(gamma='auto'),\n",
" param_grid={'C': [1, 10, 20], 'kernel': ['rbf', 'linear']})"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf = GridSearchCV(SVC(gamma='auto'),\n",
" {'C': [1, 10, 20], \n",
" 'kernel': ['rbf', 'linear']},\n",
" cv=5,\n",
" return_train_score=False)\n",
"clf.fit(iris.data, iris.target)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "df47a4e7-bfbd-4670-a0b2-47acf9de4344",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'mean_fit_time': array([0.00140004, 0.00140204, 0.00119882, 0.00080657, 0.00059376,\n",
" 0.00059996]),\n",
" 'std_fit_time': array([0.00049091, 0.00049348, 0.00040064, 0.00040348, 0.00048495,\n",
" 0.00048986]),\n",
" 'mean_score_time': array([0.00060101, 0.0012032 , 0.00079985, 0.00060096, 0.00060592,\n",
" 0.00039415]),\n",
" 'std_score_time': array([0.00049072, 0.00040417, 0.00040044, 0.00049068, 0.00049487,\n",
" 0.00048292]),\n",
" 'param_C': masked_array(data=[1, 1, 10, 10, 20, 20],\n",
" mask=[False, False, False, False, False, False],\n",
" fill_value='?',\n",
" dtype=object),\n",
" 'param_kernel': masked_array(data=['rbf', 'linear', 'rbf', 'linear', 'rbf', 'linear'],\n",
" mask=[False, False, False, False, False, False],\n",
" fill_value='?',\n",
" dtype=object),\n",
" 'params': [{'C': 1, 'kernel': 'rbf'},\n",
" {'C': 1, 'kernel': 'linear'},\n",
" {'C': 10, 'kernel': 'rbf'},\n",
" {'C': 10, 'kernel': 'linear'},\n",
" {'C': 20, 'kernel': 'rbf'},\n",
" {'C': 20, 'kernel': 'linear'}],\n",
" 'split0_test_score': array([0.96666667, 0.96666667, 0.96666667, 1. , 0.96666667,\n",
" 1. ]),\n",
" 'split1_test_score': array([1., 1., 1., 1., 1., 1.]),\n",
" 'split2_test_score': array([0.96666667, 0.96666667, 0.96666667, 0.9 , 0.9 ,\n",
" 0.9 ]),\n",
" 'split3_test_score': array([0.96666667, 0.96666667, 0.96666667, 0.96666667, 0.96666667,\n",
" 0.93333333]),\n",
" 'split4_test_score': array([1., 1., 1., 1., 1., 1.]),\n",
" 'mean_test_score': array([0.98 , 0.98 , 0.98 , 0.97333333, 0.96666667,\n",
" 0.96666667]),\n",
" 'std_test_score': array([0.01632993, 0.01632993, 0.01632993, 0.03887301, 0.03651484,\n",
" 0.0421637 ]),\n",
" 'rank_test_score': array([1, 1, 1, 4, 5, 6])}"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.cv_results_"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "3842bced-5ac0-4bf2-9adf-e463bfa26db8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>mean_fit_time</th>\n",
" <th>std_fit_time</th>\n",
" <th>mean_score_time</th>\n",
" <th>std_score_time</th>\n",
" <th>param_C</th>\n",
" <th>param_kernel</th>\n",
" <th>params</th>\n",
" <th>split0_test_score</th>\n",
" <th>split1_test_score</th>\n",
" <th>split2_test_score</th>\n",
" <th>split3_test_score</th>\n",
" <th>split4_test_score</th>\n",
" <th>mean_test_score</th>\n",
" <th>std_test_score</th>\n",
" <th>rank_test_score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.001400</td>\n",
" <td>0.000491</td>\n",
" <td>0.000601</td>\n",
" <td>0.000491</td>\n",
" <td>1</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 1, 'kernel': 'rbf'}</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.966667</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.980000</td>\n",
" <td>0.016330</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.001402</td>\n",
" <td>0.000493</td>\n",
" <td>0.001203</td>\n",
" <td>0.000404</td>\n",
" <td>1</td>\n",
" <td>linear</td>\n",
" <td>{'C': 1, 'kernel': 'linear'}</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.966667</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.980000</td>\n",
" <td>0.016330</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.001199</td>\n",
" <td>0.000401</td>\n",
" <td>0.000800</td>\n",
" <td>0.000400</td>\n",
" <td>10</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 10, 'kernel': 'rbf'}</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.966667</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.980000</td>\n",
" <td>0.016330</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.000807</td>\n",
" <td>0.000403</td>\n",
" <td>0.000601</td>\n",
" <td>0.000491</td>\n",
" <td>10</td>\n",
" <td>linear</td>\n",
" <td>{'C': 10, 'kernel': 'linear'}</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>0.900000</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.973333</td>\n",
" <td>0.038873</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.000594</td>\n",
" <td>0.000485</td>\n",
" <td>0.000606</td>\n",
" <td>0.000495</td>\n",
" <td>20</td>\n",
" <td>rbf</td>\n",
" <td>{'C': 20, 'kernel': 'rbf'}</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.900000</td>\n",
" <td>0.966667</td>\n",
" <td>1.0</td>\n",
" <td>0.966667</td>\n",
" <td>0.036515</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.000600</td>\n",
" <td>0.000490</td>\n",
" <td>0.000394</td>\n",
" <td>0.000483</td>\n",
" <td>20</td>\n",
" <td>linear</td>\n",
" <td>{'C': 20, 'kernel': 'linear'}</td>\n",
" <td>1.000000</td>\n",
" <td>1.0</td>\n",
" <td>0.900000</td>\n",
" <td>0.933333</td>\n",
" <td>1.0</td>\n",
" <td>0.966667</td>\n",
" <td>0.042164</td>\n",
" <td>6</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" mean_fit_time std_fit_time mean_score_time std_score_time param_C \\\n",
"0 0.001400 0.000491 0.000601 0.000491 1 \n",
"1 0.001402 0.000493 0.001203 0.000404 1 \n",
"2 0.001199 0.000401 0.000800 0.000400 10 \n",
"3 0.000807 0.000403 0.000601 0.000491 10 \n",
"4 0.000594 0.000485 0.000606 0.000495 20 \n",
"5 0.000600 0.000490 0.000394 0.000483 20 \n",
"\n",
" param_kernel params split0_test_score \\\n",
"0 rbf {'C': 1, 'kernel': 'rbf'} 0.966667 \n",
"1 linear {'C': 1, 'kernel': 'linear'} 0.966667 \n",
"2 rbf {'C': 10, 'kernel': 'rbf'} 0.966667 \n",
"3 linear {'C': 10, 'kernel': 'linear'} 1.000000 \n",
"4 rbf {'C': 20, 'kernel': 'rbf'} 0.966667 \n",
"5 linear {'C': 20, 'kernel': 'linear'} 1.000000 \n",
"\n",
" split1_test_score split2_test_score split3_test_score split4_test_score \\\n",
"0 1.0 0.966667 0.966667 1.0 \n",
"1 1.0 0.966667 0.966667 1.0 \n",
"2 1.0 0.966667 0.966667 1.0 \n",
"3 1.0 0.900000 0.966667 1.0 \n",
"4 1.0 0.900000 0.966667 1.0 \n",
"5 1.0 0.900000 0.933333 1.0 \n",
"\n",
" mean_test_score std_test_score rank_test_score \n",
"0 0.980000 0.016330 1 \n",
"1 0.980000 0.016330 1 \n",
"2 0.980000 0.016330 1 \n",
"3 0.973333 0.038873 4 \n",
"4 0.966667 0.036515 5 \n",
"5 0.966667 0.042164 6 "
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(clf.cv_results_)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "9c093072-048b-4187-9dd4-b4042edd0633",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>param_C</th>\n",
" <th>param_kernel</th>\n",
" <th>mean_test_score</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>rbf</td>\n",
" <td>0.980000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1</td>\n",
" <td>linear</td>\n",
" <td>0.980000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>10</td>\n",
" <td>rbf</td>\n",
" <td>0.980000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>10</td>\n",
" <td>linear</td>\n",
" <td>0.973333</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>20</td>\n",
" <td>rbf</td>\n",
" <td>0.966667</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>20</td>\n",
" <td>linear</td>\n",
" <td>0.966667</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" param_C param_kernel mean_test_score\n",
"0 1 rbf 0.980000\n",
"1 1 linear 0.980000\n",
"2 10 rbf 0.980000\n",
"3 10 linear 0.973333\n",
"4 20 rbf 0.966667\n",
"5 20 linear 0.966667"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[[\"param_C\", \"param_kernel\", \"mean_test_score\"]]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c184bb27-fc8b-45d4-80fd-39251da1ef3b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['__abstractmethods__',\n",
" '__annotations__',\n",
" '__class__',\n",
" '__delattr__',\n",
" '__dict__',\n",
" '__dir__',\n",
" '__doc__',\n",
" '__eq__',\n",
" '__format__',\n",
" '__ge__',\n",
" '__getattribute__',\n",
" '__getstate__',\n",
" '__gt__',\n",
" '__hash__',\n",
" '__init__',\n",
" '__init_subclass__',\n",
" '__le__',\n",
" '__lt__',\n",
" '__module__',\n",
" '__ne__',\n",
" '__new__',\n",
" '__reduce__',\n",
" '__reduce_ex__',\n",
" '__repr__',\n",
" '__setattr__',\n",
" '__setstate__',\n",
" '__sizeof__',\n",
" '__sklearn_clone__',\n",
" '__str__',\n",
" '__subclasshook__',\n",
" '__weakref__',\n",
" '_abc_impl',\n",
" '_build_request_for_signature',\n",
" '_check_feature_names',\n",
" '_check_n_features',\n",
" '_check_refit_for_multimetric',\n",
" '_estimator_type',\n",
" '_format_results',\n",
" '_get_default_requests',\n",
" '_get_metadata_request',\n",
" '_get_param_names',\n",
" '_get_tags',\n",
" '_more_tags',\n",
" '_parameter_constraints',\n",
" '_repr_html_',\n",
" '_repr_html_inner',\n",
" '_repr_mimebundle_',\n",
" '_required_parameters',\n",
" '_run_search',\n",
" '_select_best_index',\n",
" '_validate_data',\n",
" '_validate_params',\n",
" 'best_estimator_',\n",
" 'best_index_',\n",
" 'best_params_',\n",
" 'best_score_',\n",
" 'classes_',\n",
" 'cv',\n",
" 'cv_results_',\n",
" 'decision_function',\n",
" 'error_score',\n",
" 'estimator',\n",
" 'fit',\n",
" 'get_metadata_routing',\n",
" 'get_params',\n",
" 'inverse_transform',\n",
" 'multimetric_',\n",
" 'n_features_in_',\n",
" 'n_jobs',\n",
" 'n_splits_',\n",
" 'param_grid',\n",
" 'pre_dispatch',\n",
" 'predict',\n",
" 'predict_log_proba',\n",
" 'predict_proba',\n",
" 'refit',\n",
" 'refit_time_',\n",
" 'return_train_score',\n",
" 'score',\n",
" 'score_samples',\n",
" 'scorer_',\n",
" 'scoring',\n",
" 'set_fit_request',\n",
" 'set_params',\n",
" 'transform',\n",
" 'verbose']"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dir(clf)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "451a4f85-38e9-4928-a43c-29a98ef11b9a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9800000000000001"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.best_score_"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0a18515f-7651-43a0-a8c6-b10888140ab1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'C': 1, 'kernel': 'rbf'}"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf.best_params_"
]
},
{
"cell_type": "markdown",
"id": "0bf652ae-3c18-4340-ae99-006ab45e1088",
"metadata": {},
"source": [
"不同模型,谁最好"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "478f9706-f3a2-4bd6-9e90-da52e7f28ccc",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.svm import SVC\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.linear_model import LogisticRegression"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "8398663b-7d4e-4d0b-9a2c-69c7934a5084",
"metadata": {},
"outputs": [],
"source": [
"model_params = {\n",
" 'svm':{\n",
" 'model': SVC(gamma='auto'),\n",
" 'params': {\n",
" 'C': [1, 10, 20],\n",
" 'kernel': ['rbf', 'linear']\n",
" }\n",
" },\n",
" 'random_forest':{\n",
" 'model': RandomForestClassifier(),\n",
" 'params': {\n",
" 'n_estimators': [1, 5, 10]\n",
" }\n",
" },\n",
" 'logistic_regression': {\n",
" 'model': LogisticRegression(solver='liblinear', multi_class='auto'),\n",
" 'params': {\n",
" 'C': [1, 5, 10]\n",
" }\n",
" }\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "f072707d-a124-4f61-bbb3-2f26b5c10eec",
"metadata": {},
"outputs": [],
"source": [
"scores = []\n",
"for model, mp in model_params.items():\n",
" clf = GridSearchCV(mp['model'], mp['params'], cv=5, return_train_score=False)\n",
" clf.fit(iris.data, iris.target)\n",
" scores.append({'model': model,\n",
" 'best_score': clf.best_score_,\n",
" 'best_params': clf.best_params_})\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "e68db3aa-6a69-4662-899b-e73690272851",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'model': 'svm',\n",
" 'best_score': 0.9800000000000001,\n",
" 'best_params': {'C': 1, 'kernel': 'rbf'}},\n",
" {'model': 'random_forest',\n",
" 'best_score': 0.9600000000000002,\n",
" 'best_params': {'n_estimators': 5}},\n",
" {'model': 'logistic_regression',\n",
" 'best_score': 0.9666666666666668,\n",
" 'best_params': {'C': 5}}]"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scores"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "0279fa4c-7058-41c7-ad16-369bf4206794",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>model</th>\n",
" <th>best_score</th>\n",
" <th>best_params</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>svm</td>\n",
" <td>0.980000</td>\n",
" <td>{'C': 1, 'kernel': 'rbf'}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>random_forest</td>\n",
" <td>0.960000</td>\n",
" <td>{'n_estimators': 5}</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>logistic_regression</td>\n",
" <td>0.966667</td>\n",
" <td>{'C': 5}</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" model best_score best_params\n",
"0 svm 0.980000 {'C': 1, 'kernel': 'rbf'}\n",
"1 random_forest 0.960000 {'n_estimators': 5}\n",
"2 logistic_regression 0.966667 {'C': 5}"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.DataFrame(scores, columns=['model', 'best_score', 'best_params'])\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "f396a9c8-256e-4d36-813c-9a39e67e39a3",
"metadata": {},
"source": [
"作业:使用手写体数据,找出前面我们课程中学习的最好模型。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47dab0d6-5a29-4e8b-a74e-24a6239d0e56",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}