{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c653c8dc-4405-4665-ad7d-d6efb4f8c4ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4460cd03-9a31-48bb-804a-1b6abaa1355d",
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = load_iris()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95a05439-f1d0-48f3-832a-25e3d7045341",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['DESCR',\n",
       " 'data',\n",
       " 'data_module',\n",
       " 'feature_names',\n",
       " 'filename',\n",
       " 'frame',\n",
       " 'target',\n",
       " 'target_names']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dir(iris)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fa4945b9-1ad4-41f3-ad03-9047953fdfc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "iris_data = iris[\"data\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "97d82f62-5653-4b8d-a57d-af3e43e3793c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150, 4)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7deb6bbe-c7ab-460f-ace9-eabe33d08c24",
   "metadata": {},
   "outputs": [],
   "source": [
    "iris_target = iris[\"target\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "327beb9b-e0b7-43ad-9703-f2a6239d47e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150,)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris_target.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4f5b6339-0d0b-4ed8-8320-0796ea0193ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train, x_test, y_train, y_test = train_test_split(iris_data, iris_target, test_size=0.2, random_state=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "97a6b729-80ac-4a0e-8220-e9af7f997581",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(150, 120, 30)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(iris_data), len(x_train), len(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9bc3d836-e866-4c81-9197-e1b3afc6ddc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_euclidean_distance(x, dataset):\n",
    "    \"\"\"Return the Euclidean (L2) distance from sample `x` to every row of `dataset`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array-like, shape (n_features,)\n",
    "        Query sample; must broadcast against the rows of `dataset`.\n",
    "    dataset : array-like, shape (n_samples, n_features)\n",
    "        Reference samples, one per row.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (n_samples, 1)\n",
    "        Column of distances (`keepdims=True` keeps the result 2-D).\n",
    "    \"\"\"\n",
    "    squared_diff = (np.asarray(dataset) - np.asarray(x)) ** 2\n",
    "    # Sum the squared differences over features *before* taking the root;\n",
    "    # taking sqrt of each squared term first yields |a - b|, i.e. the Manhattan (L1) distance.\n",
    "    return np.sqrt(np.sum(squared_diff, axis=1, keepdims=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e3d5836d-a5bc-4098-a915-d493bce4630a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "120"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(x_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2f221d57-3fbb-45e2-90dd-ad3ca935da05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def knn(x_test, x_train, labels, k):\n",
    "    \"\"\"Classify samples by majority vote among their `k` nearest training samples.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x_test : array-like, shape (n_queries, n_features) or (n_features,)\n",
    "        Samples to classify; a single 1-D sample is treated as a batch of one.\n",
    "    x_train : array-like, shape (n_samples, n_features)\n",
    "        Training samples.\n",
    "    labels : array-like of non-negative ints, shape (n_samples,)\n",
    "        Class label of each training sample (non-negative ints are required by `np.bincount`).\n",
    "    k : int\n",
    "        Number of nearest neighbours that vote; must be >= 1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (n_queries,)\n",
    "        Predicted label per query; vote ties go to the smallest label (`np.argmax`).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `k` is less than 1 (an empty vote has no winner).\n",
    "    \"\"\"\n",
    "    if k < 1:\n",
    "        raise ValueError(f\"k must be >= 1, got {k}\")\n",
    "    x_test = np.atleast_2d(x_test)  # promote a single sample to shape (1, n_features)\n",
    "    labels = np.asarray(labels).ravel()\n",
    "    y_preds = []\n",
    "    for sample in x_test:\n",
    "        distances = compute_euclidean_distance(sample, x_train).ravel()\n",
    "        # Stable sort: equidistant neighbours are taken in training-set order, deterministically.\n",
    "        nearest = np.argsort(distances, kind=\"stable\")[:k]\n",
    "        # Majority vote: count each label among the neighbours and keep the most frequent.\n",
    "        y_preds.append(np.argmax(np.bincount(labels[nearest])))\n",
    "    return np.array(y_preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3ee80916-e5ae-4b7b-b83d-2019b86baba6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 2, 0, 1, 0, 1, 2, 1, 0, 1, 1, 2, 1, 0, 0, 2, 1, 0, 0, 0, 2, 2,\n",
       " 2, 0, 1, 0, 1, 1, 1, 
2], dtype=int64)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "knn(x_test, x_train, y_train, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "90b3ce57-aeed-489a-9135-edbe32943c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(y_preds, y_true):\n",
    "    \"\"\"Return the fraction of predictions that equal the true labels.\n",
    "\n",
    "    Both inputs are flattened to 1-D, so an (n, 1) column of predictions is compared\n",
    "    element-wise with an (n,) label vector instead of broadcasting to an (n, n) matrix.\n",
    "    Python lists are accepted as well as NumPy arrays.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the inputs differ in length, or are empty (accuracy is undefined).\n",
    "    \"\"\"\n",
    "    y_preds = np.asarray(y_preds).ravel()\n",
    "    y_true = np.asarray(y_true).ravel()\n",
    "    if y_preds.shape != y_true.shape:\n",
    "        raise ValueError(f\"length mismatch: {y_preds.size} predictions vs {y_true.size} labels\")\n",
    "    if y_true.size == 0:\n",
    "        raise ValueError(\"accuracy is undefined for empty inputs\")\n",
    "    return np.mean(y_preds == y_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "031c3693-e54a-458a-9596-86775ca785ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_preds = knn(x_test, x_train, y_train, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "75db42e7-1060-462f-b3be-0a52a491d34a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy(y_preds, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "23d91fd3-e21e-4987-b710-b95f90bfb1f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9db8922-65a3-45cd-89d5-9e70349c7570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference model to cross-check the hand-written knn: same k (8) as above, and p=2 selects the\n",
    "# Euclidean metric. p < 1 is not a true distance metric (sklearn warns and forces brute force).\n",
    "sk_knn = KNeighborsClassifier(n_neighbors=8, p=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6f1f7d-ce1e-4f55-a8d9-67d27a1e91e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sk_knn.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a2cd011-6d8c-4fa9-a5cd-863142341a02",
   "metadata": {},
   "outputs": [],
   "source": [
    "sk_knn.score(x_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}