You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
330 lines
7.8 KiB
Plaintext
330 lines
7.8 KiB
Plaintext
6 months ago
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"id": "2c299e40-b721-4b5e-8e48-b1a89c95e15a",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.datasets import load_digits\n",
|
||
|
"from sklearn.svm import SVC\n",
|
||
|
"from sklearn.linear_model import LinearRegression\n",
|
||
|
"from sklearn.ensemble import RandomForestClassifier"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"id": "7744e7e0-f550-4fd6-8efd-1378e94d0139",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# scikit-learn's bundled digits dataset: features in digits.data, labels in digits.target\n",
"digits = load_digits(return_X_y=False)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"id": "6726063c-efca-440d-bd56-3669ce5fb813",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"model_svc = SVC()\n",
"# NOTE: LinearRegression is a regressor; its .score() is R^2, not classification accuracy,\n",
"# so its numbers in this notebook are not directly comparable with the two classifiers.\n",
"model_lr = LinearRegression()\n",
"# random_state fixes the random seed so the forest is reproducible; it is NOT the number of trees.\n",
"# The tree count (n_estimators) is left at scikit-learn's default here.\n",
"model_rf = RandomForestClassifier(random_state=30)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"id": "88cf0342-10d6-44ee-96eb-be12ed3fa29e",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import train_test_split"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"id": "11e2a509-201d-4b59-9e31-2bb44559c6e4",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Hold out 30% of the samples for testing; a fixed random_state makes the split (and the scores below) reproducible.\n",
"x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.3, random_state=30)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"id": "1f44f871-e1cb-4467-8fc7-d3c0d2ce79bc",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.9907407407407407"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# fit() returns the fitted estimator, so training and test-set scoring chain into one expression\n",
"model_svc.fit(x_train, y_train).score(x_test, y_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"id": "20500031-5e33-415c-bf39-47db2479eebb",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.5545622492919391"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# fit() returns the fitted estimator, so training and test-set scoring chain into one expression\n",
"model_lr.fit(x_train, y_train).score(x_test, y_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"id": "3bd09dc7-77e0-487a-8ebb-f808bd8e2238",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"0.9722222222222222"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# fit() returns the fitted estimator, so training and test-set scoring chain into one expression\n",
"model_rf.fit(x_train, y_train).score(x_test, y_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 15,
|
||
|
"id": "28bccfa6-875a-475e-89c7-f6cde6aa4a34",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"def model_score(model, x_train, y_train, x_test, y_test):\n",
"    \"\"\"Train `model` on the training split and return its score on the test split.\n",
"\n",
"    Calls model.fit(x_train, y_train) (its return value is ignored), then returns\n",
"    whatever model.score(x_test, y_test) produces.\n",
"    \"\"\"\n",
"    model.fit(x_train, y_train)  # train\n",
"    test_score = model.score(x_test, y_test)  # evaluate on the held-out split\n",
"    return test_score"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"id": "0cd7da6f-bddf-4617-8eee-00c438e63cbf",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import KFold"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"id": "72bd5b93-23c2-49d4-86b9-14f61c2872be",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Three contiguous, deterministic folds: KFold only shuffles when shuffle=True is passed,\n",
"# so this is NOT a random split (see the consecutive index blocks printed below).\n",
"kf = KFold(n_splits=3, shuffle=False)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"id": "9bb0a224-cae5-4752-b762-d8a1612f254f",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"[3 4 5 6 7 8] [0 1 2]\n",
|
||
|
"[0 1 2 6 7 8] [3 4 5]\n",
|
||
|
"[0 1 2 3 4 5] [6 7 8]\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Toy example: show which of 9 sample indices land in train vs. test for each fold\n",
"for train_index, test_index in kf.split(list(range(1, 10))):\n",
"    print(train_index, test_index)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 24,
|
||
|
"id": "1231ff85-33b1-4f39-a1ac-5d61dba6e633",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Score every model on each KFold split. Fold-local names keep the x_train/x_test/\n",
"# y_train/y_test split from the train_test_split cell intact, so re-running the\n",
"# single-split cells afterwards still uses that split rather than the last fold.\n",
"score_lr = []\n",
"score_svc = []\n",
"score_rf = []\n",
"for train_index, test_index in kf.split(digits.data):\n",
"    fold_x_train, fold_x_test = digits.data[train_index], digits.data[test_index]\n",
"    fold_y_train, fold_y_test = digits.target[train_index], digits.target[test_index]\n",
"    score_lr.append(model_score(model_lr, fold_x_train, fold_y_train, fold_x_test, fold_y_test))\n",
"    score_svc.append(model_score(model_svc, fold_x_train, fold_y_train, fold_x_test, fold_y_test))\n",
"    score_rf.append(model_score(model_rf, fold_x_train, fold_y_train, fold_x_test, fold_y_test))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 26,
|
||
|
"id": "0a706285-b4be-4b3c-a454-542c1c8c7675",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0.502877527901132\n",
|
||
|
"0.9677239844184752\n",
|
||
|
"0.9354479688369505\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Average of the per-fold scores, one line per model (LR, SVC, RF)\n",
"for fold_scores in (score_lr, score_svc, score_rf):\n",
"    print(sum(fold_scores) / len(fold_scores))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 32,
|
||
|
"id": "cfb2a2ff-0ad0-4bca-9e6f-7956eb46605b",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import StratifiedKFold"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 34,
|
||
|
"id": "6db430e4-e586-4861-8685-56af7dcec9e2",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Stratified folds: each of the 3 folds keeps roughly the same proportion of every class\n",
"folds = StratifiedKFold(n_splits=3, shuffle=False)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 37,
|
||
|
"id": "6d17b2b0-985c-4101-965b-15b8769e55ed",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"0.5008147357113389\n",
|
||
|
"0.9699499165275459\n",
|
||
|
"0.9382303839732887\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# Same comparison with stratified folds. Fold-local and strat_-prefixed names keep the\n",
"# train_test_split variables and the plain-KFold score lists from being overwritten.\n",
"strat_score_lr = []\n",
"strat_score_svc = []\n",
"strat_score_rf = []\n",
"for train_index, test_index in folds.split(digits.data, digits.target):\n",
"    fold_x_train, fold_x_test = digits.data[train_index], digits.data[test_index]\n",
"    fold_y_train, fold_y_test = digits.target[train_index], digits.target[test_index]\n",
"    strat_score_lr.append(model_score(model_lr, fold_x_train, fold_y_train, fold_x_test, fold_y_test))\n",
"    strat_score_svc.append(model_score(model_svc, fold_x_train, fold_y_train, fold_x_test, fold_y_test))\n",
"    strat_score_rf.append(model_score(model_rf, fold_x_train, fold_y_train, fold_x_test, fold_y_test))\n",
"print(sum(strat_score_lr) / len(strat_score_lr))\n",
"print(sum(strat_score_svc) / len(strat_score_svc))\n",
"print(sum(strat_score_rf) / len(strat_score_rf))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 27,
|
||
|
"id": "40ec04e2-d292-46b7-ba86-da6b03f7f76c",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import cross_val_score"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 28,
|
||
|
"id": "e757dc8a-4228-4bb6-a432-f869c6e6870d",
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"array([0.48346048, 0.5583603 , 0.57534522, 0.5056632 , 0.40995457])"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 28,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# 5-fold cross-validation in a single call (cv=5 gives the five scores shown in the output)\n",
"cross_val_score(model_lr, digits.data, digits.target, cv=5)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"id": "ac1037ae-9d3e-48ec-963c-9b0ca075b9cf",
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.0"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 5
|
||
|
}
|