You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

330 lines
7.8 KiB
Plaintext

6 months ago
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "2c299e40-b721-4b5e-8e48-b1a89c95e15a",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_digits\n",
"from sklearn.svm import SVC\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.ensemble import RandomForestClassifier"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7744e7e0-f550-4fd6-8efd-1378e94d0139",
"metadata": {},
"outputs": [],
"source": [
"digits = load_digits()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6726063c-efca-440d-bd56-3669ce5fb813",
"metadata": {},
"outputs": [],
"source": [
"model_svc = SVC()\n",
"model_lr = LinearRegression()\n",
"model_rf = RandomForestClassifier(random_state=30)#30棵树"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "88cf0342-10d6-44ee-96eb-be12ed3fa29e",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "11e2a509-201d-4b59-9e31-2bb44559c6e4",
"metadata": {},
"outputs": [],
"source": [
"x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1f44f871-e1cb-4467-8fc7-d3c0d2ce79bc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9907407407407407"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_svc.fit(x_train, y_train)\n",
"model_svc.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "20500031-5e33-415c-bf39-47db2479eebb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5545622492919391"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_lr.fit(x_train, y_train)\n",
"model_lr.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3bd09dc7-77e0-487a-8ebb-f808bd8e2238",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9722222222222222"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_rf.fit(x_train, y_train)\n",
"model_rf.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "28bccfa6-875a-475e-89c7-f6cde6aa4a34",
"metadata": {},
"outputs": [],
"source": [
"def model_score(model, x_train, y_train, x_test, y_test):\n",
" model.fit(x_train, y_train)#训练\n",
" return model.score(x_test, y_test)#测试"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0cd7da6f-bddf-4617-8eee-00c438e63cbf",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import KFold"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "72bd5b93-23c2-49d4-86b9-14f61c2872be",
"metadata": {},
"outputs": [],
"source": [
"kf = KFold(n_splits=3)#随机拆分"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "9bb0a224-cae5-4752-b762-d8a1612f254f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[3 4 5 6 7 8] [0 1 2]\n",
"[0 1 2 6 7 8] [3 4 5]\n",
"[0 1 2 3 4 5] [6 7 8]\n"
]
}
],
"source": [
"for train_index, test_index in kf.split([1, 2, 3, 4, 5, 6, 7, 8, 9]):\n",
" print(train_index, test_index)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1231ff85-33b1-4f39-a1ac-5d61dba6e633",
"metadata": {},
"outputs": [],
"source": [
"score_lr = []\n",
"score_svc = []\n",
"score_rf = []\n",
"for train_index, test_index in kf.split(digits.data):\n",
" x_train, x_test, y_train, y_test = digits.data[train_index], digits.data[test_index], \\\n",
" digits.target[train_index], digits.target[test_index]\n",
" score_lr.append(model_score(model_lr, x_train, y_train, x_test, y_test))\n",
" score_svc.append(model_score(model_svc, x_train, y_train, x_test, y_test))\n",
" score_rf.append(model_score(model_rf, x_train, y_train, x_test, y_test))\n"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0a706285-b4be-4b3c-a454-542c1c8c7675",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.502877527901132\n",
"0.9677239844184752\n",
"0.9354479688369505\n"
]
}
],
"source": [
"print(sum(score_lr) / len(score_lr))\n",
"print(sum(score_svc) / len(score_svc))\n",
"print(sum(score_rf) / len(score_rf))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "cfb2a2ff-0ad0-4bca-9e6f-7956eb46605b",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import StratifiedKFold"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "6db430e4-e586-4861-8685-56af7dcec9e2",
"metadata": {},
"outputs": [],
"source": [
"folds = StratifiedKFold(n_splits=3)#对每个类进行平均拆分"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "6d17b2b0-985c-4101-965b-15b8769e55ed",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.5008147357113389\n",
"0.9699499165275459\n",
"0.9382303839732887\n"
]
}
],
"source": [
"score_lr = []\n",
"score_svc = []\n",
"score_rf = []\n",
"for train_index, test_index in folds.split(digits.data, digits.target):\n",
" x_train, x_test, y_train, y_test = digits.data[train_index], digits.data[test_index], \\\n",
" digits.target[train_index], digits.target[test_index]\n",
" score_lr.append(model_score(model_lr, x_train, y_train, x_test, y_test))\n",
" score_svc.append(model_score(model_svc, x_train, y_train, x_test, y_test))\n",
" score_rf.append(model_score(model_rf, x_train, y_train, x_test, y_test))\n",
"print(sum(score_lr) / len(score_lr))\n",
"print(sum(score_svc) / len(score_svc))\n",
"print(sum(score_rf) / len(score_rf))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "40ec04e2-d292-46b7-ba86-da6b03f7f76c",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_score"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "e757dc8a-4228-4bb6-a432-f869c6e6870d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.48346048, 0.5583603 , 0.57534522, 0.5056632 , 0.40995457])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cross_val_score(model_lr, digits.data, digits.target)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac1037ae-9d3e-48ec-963c-9b0ca075b9cf",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}