{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8f3ff90a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import nnfs\n", "from nnfs.datasets import spiral_data\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "id": "439ab0e4", "metadata": {}, "outputs": [], "source": [ "nnfs.init()" ] }, { "cell_type": "code", "execution_count": 39, "id": "2bc1bb6c", "metadata": {}, "outputs": [], "source": [ "class Linear():\n", " def __init__(self, in_features, out_features):\n", " self.weights = np.random.randn(in_features, out_features) * 0.1\n", " self.bias = np.zeros((1, out_features))\n", " def forward(self, inputs):\n", " self.outputs = np.dot(inputs, self.weights) + self.bias\n", " self.inputs = inputs\n", " def backward(self, dvalues):\n", " #计算权重的梯度\n", " self.dweights = np.dot(self.inputs.T, dvalues)\n", " #计算偏置的梯度\n", " self.dbiases = np.sum(dvalues, axis=0, keepdims=True)\n", " #计算输入的梯度\n", " self.dinputs = np.dot(dvalues, self.weights.T)\n", "class Relu():\n", " def forward(self, inputs):\n", " self.outputs = np.maximum(0, inputs)\n", " self.inputs = inputs\n", " def backward(self, dvalues):\n", " self.dinputs = dvalues.copy()\n", " self.dinputs[self.inputs <= 0] = 0\n", "class Softmax():\n", " def forward(self, inputs):\n", " exp_values = np.exp(inputs - np.max(inputs, axis=1, keepdims=True))\n", " self.outputs = exp_values / np.sum(exp_values, axis=1, keepdims=True)\n", " def backward(self, dvalues):\n", " self.dinputs = np.empty_like(dvalues)\n", " for index, (single_output, single_dvalues) in enumerate(zip(self.outputs, dvalues)):\n", " single_output = single_output.reshape(-1, 1)\n", " jacobian_matrix = np.diagflat(single_output) - np.dot(single_output, single_output.T)\n", " self.dinputs[index] = np.dot(jacobian_matrix, single_dvalues)\n", "class Loss():\n", " def calculate(self, outputs, y):\n", " sample_loss = self.forward(outputs, y)\n", " data_loss = np.mean(sample_loss)\n", " return data_loss\n", "\"\"\"交叉熵损失函数\"\"\"\n", "class CategoricalCrossEntropy(Loss):\n", " def one_hot(self, y_true):\n", " y_true = np.eye(max(y_true) + 1)[y_true]\n", " return y_true\n", " def forward(self, y_pred, y_true):\n", " y_pred_clipped = np.clip(y_pred, 1e-7, 1-1e-7)#截断,放置np.log溢出\n", " if len(y_true.shape) == 1:#标签不是独热码,通过one_hot函数转化\n", " y_true = self.one_hot(y_true)\n", " confidences = np.sum(y_pred_clipped * y_true, axis=1)\n", " negative_log_likehoods = -np.log(confidences)\n", " return negative_log_likehoods\n", " def backward(self, dvalues, y_true):\n", " samples = len(dvalues)\n", " if len(y_true.shape) == 1:#将标签转化为onehot\n", " y_true = self.one_hot(y_true)\n", " self.dinputs = -y_true / dvalues\n", " self.dinputs = self.dinputs /samples\n", "\"\"\"准确率\"\"\"\n", "class Accuracy():\n", " def forward(self, logits, y_true):\n", " predictions = np.argmax(logits, axis=1)#从激活输出选出最大值对应的下标即使标签\n", " if len(y_true.shape) == 2:#独热编码\n", " y_true = np.max(y_true, axis=1)\n", " accuracy = np.mean(predictions == y_true)\n", " return accuracy" ] }, { "cell_type": "code", "execution_count": 41, "id": "5f334df4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(300, 3)\n", "[[-0.00222222 0.00111111 0.00111111]\n", " [-0.00222225 0.00111109 0.00111116]\n", " [-0.00222228 0.00111107 0.00111121]\n", " [-0.0022223 0.00111105 0.00111125]\n", " [-0.00222237 0.001111 0.00111137]\n", " [-0.00222236 0.00111101 0.00111135]\n", " [-0.0022224 0.00111098 0.00111142]\n", " [-0.00222249 0.00111092 0.00111157]\n", " [-0.00222261 0.00111093 0.00111168]\n", " [-0.00222254 0.00111087 0.00111167]\n", " [-0.00222249 0.0011109 0.00111159]\n", " [-0.00222283 0.00111091 0.00111191]\n", " [-0.00222279 0.00111083 0.00111196]\n", " [-0.00222299 0.00111105 0.00111194]\n", " [-0.00222305 0.001111 0.00111206]\n", " [-0.00222307 0.00111088 0.0011122 ]\n", " [-0.00222317 0.00111096 0.00111221]\n", " [-0.00222318 0.00111083 0.00111235]\n", " [-0.00222208 0.00111125 0.00111082]\n", " [-0.00222333 0.00111087 0.00111246]\n", " [-0.0022217 0.00111131 0.00111039]\n", " [-0.00222209 0.00111127 0.00111082]\n", " [-0.00222284 0.00111118 0.00111166]\n", " [-0.00222255 0.00111123 0.00111132]\n", " [-0.0022201 0.00111003 0.00111007]\n", " [-0.00222371 0.00111093 0.00111278]\n", " [-0.00221758 0.0011059 0.00111168]\n", " [-0.00222071 0.00111093 0.00110978]\n", " [-0.0022197 0.00110978 0.00110992]\n", " [-0.00221763 0.00110665 0.00111098]\n", " [-0.00221808 0.00110754 0.00111054]\n", " [-0.00221782 0.00110463 0.00111319]\n", " [-0.00221667 0.00110517 0.0011115 ]\n", " [-0.00221904 0.00110925 0.00110979]\n", " [-0.00221664 0.00110549 0.00111115]\n", " [-0.00221663 0.00110352 0.00111311]\n", " [-0.00221696 0.00110634 0.00111062]\n", " [-0.00221568 0.0011039 0.00111178]\n", " [-0.00222012 0.00110572 0.00111439]\n", " [-0.00221633 0.00110278 0.00111354]\n", " [-0.00221551 0.00110421 0.0011113 ]\n", " [-0.00221524 0.00110213 0.00111312]\n", " [-0.00221796 0.00110348 0.00111447]\n", " [-0.00222175 0.00110951 0.00111224]\n", " [-0.00221539 0.00110164 0.00111375]\n", " [-0.00221868 0.00110379 0.00111489]\n", " [-0.00221835 0.00110341 0.00111494]\n", " [-0.00221633 0.00110177 0.00111457]\n", " [-0.00221601 0.00110144 0.00111457]\n", " [-0.00221604 0.00110134 0.0011147 ]\n", " [-0.00221384 0.00110017 0.00111367]\n", " [-0.00222003 0.00110448 0.00111555]\n", " [-0.0022214 0.0011083 0.0011131 ]\n", " [-0.00222196 0.00111027 0.00111169]\n", " [-0.00222042 0.00110484 0.00111557]\n", " [-0.00222265 0.00111079 0.00111186]\n", " [-0.00222377 0.00110993 0.00111384]\n", " [-0.00222222 0.00111111 0.00111111]\n", " [-0.00222292 0.00111058 0.00111233]\n", " [-0.00222373 0.00110996 0.00111376]\n", " [-0.00222479 0.00110956 0.00111523]\n", " [-0.00222475 0.0011095 0.00111525]\n", " [-0.00222382 0.0011099 0.00111392]\n", " [-0.00222329 0.0011103 0.00111298]\n", " [-0.00222453 0.00110935 0.00111518]\n", " [-0.0022246 0.0011093 0.00111529]\n", " [-0.00222358 0.00111008 0.00111351]\n", " [-0.00222428 0.00110954 0.00111474]\n", " [-0.002223 0.00111052 0.00111249]\n", " [-0.00222499 0.00110926 0.00111574]\n", " [-0.00222475 0.00110919 0.00111556]\n", " [-0.00222592 0.00110965 0.00111627]\n", " [-0.00222632 0.00111 0.00111632]\n", " [-0.00222651 0.00111027 0.00111624]\n", " [-0.00222543 0.00111118 0.00111425]\n", " [-0.00222668 0.00111065 0.00111603]\n", " [-0.00222674 0.00111049 0.00111625]\n", " [-0.00221992 0.00111193 0.00110799]\n", " [-0.00222683 0.0011108 0.00111603]\n", " [-0.00222691 0.0011107 0.00111621]\n", " [-0.00222658 0.00111106 0.00111552]\n", " [-0.00222697 0.00111096 0.001116 ]\n", " [-0.00222709 0.00111055 0.00111654]\n", " [-0.00221433 0.00110659 0.00110774]\n", " [-0.00222074 0.00111188 0.00110886]\n", " [-0.0022083 0.00109718 0.00111111]\n", " [-0.00220976 0.00109994 0.00110982]\n", " [-0.00220732 0.00109548 0.00111184]\n", " [-0.00222172 0.00111178 0.00110994]\n", " [-0.00220619 0.00109271 0.00111347]\n", " [-0.00222182 0.00111179 0.00111003]\n", " [-0.00222342 0.00111158 0.00111183]\n", " [-0.00220705 0.00109585 0.0011112 ]\n", " [-0.0022114 0.00110319 0.0011082 ]\n", " [-0.00220616 0.00109434 0.00111182]\n", " [-0.00220555 0.00109035 0.0011152 ]\n", " [-0.00220542 0.001093 0.00111241]\n", " [-0.00220477 0.00109009 0.00111468]\n", " [-0.00220454 0.00109012 0.00111441]\n", " [-0.00220887 0.00109087 0.00111799]\n", " [ 0.00111111 -0.00222222 0.00111111]\n", " [ 0.00111129 -0.00222244 0.00111115]\n", " [ 0.00111138 -0.00222263 0.00111125]\n", " [ 0.00111132 -0.00222268 0.00111137]\n", " [ 0.00111165 -0.00222305 0.00111139]\n", " [ 0.00111177 -0.00222324 0.00111147]\n", " [ 0.00111177 -0.00222335 0.00111158]\n", " [ 0.00111163 -0.00222334 0.0011117 ]\n", " [ 0.0011121 -0.00222381 0.0011117 ]\n", " [ 0.00111236 -0.00222409 0.00111173]\n", " [ 0.00111203 -0.00222396 0.00111193]\n", " [ 0.00111138 -0.00222315 0.00111177]\n", " [ 0.001112 -0.00222413 0.00111213]\n", " [ 0.00111114 -0.00222232 0.00111118]\n", " [ 0.00111102 -0.00222229 0.00111126]\n", " [ 0.00111153 -0.00222368 0.00111215]\n", " [ 0.00111125 -0.00222268 0.00111143]\n", " [ 0.00111131 -0.00222288 0.00111157]\n", " [ 0.00111111 -0.00222222 0.00111111]\n", " [ 0.00111111 -0.00222223 0.00111112]\n", " [ 0.00111078 -0.00222247 0.00111169]\n", " [ 0.00111066 -0.00222256 0.0011119 ]\n", " [ 0.00111076 -0.00222249 0.00111172]\n", " [ 0.00111112 -0.00222225 0.00111113]\n", " [ 0.00111051 -0.00222268 0.00111217]\n", " [ 0.00111065 -0.00222257 0.00111192]\n", " [ 0.0011108 -0.00222246 0.00111166]\n", " [ 0.00111021 -0.00222291 0.0011127 ]\n", " [ 0.00111024 -0.00222288 0.00111264]\n", " [ 0.00110954 -0.00222277 0.00111322]\n", " [ 0.00110998 -0.00222305 0.00111307]\n", " [ 0.0011101 -0.00222299 0.00111289]\n", " [ 0.00110922 -0.00222254 0.00111332]\n", " [ 0.00110919 -0.00222264 0.00111345]\n", " [ 0.00110976 -0.00222314 0.00111338]\n", " [ 0.00110939 -0.00222301 0.00111363]\n", " [ 0.00110906 -0.00222277 0.00111371]\n", " [ 0.00110893 -0.00222263 0.00111369]\n", " [ 0.00111045 -0.00222205 0.00111159]\n", " [ 0.0011088 -0.0022224 0.0011136 ]\n", " [ 0.00111287 -0.00222205 0.00110918]\n", " [ 0.00110869 -0.00222266 0.00111397]\n", " [ 0.00110907 -0.00222221 0.00111314]\n", " [ 0.00111237 -0.00222177 0.00110939]\n", " [ 0.00110851 -0.00222239 0.00111388]\n", " [ 0.00111267 -0.0022218 0.00110914]\n", " [ 0.00111772 -0.00222811 0.00111039]\n", " [ 0.00111507 -0.00222407 0.001109 ]\n", " [ 0.00111223 -0.00222175 0.00110952]\n", " [ 0.00111217 -0.00222175 0.00110958]\n", " [ 0.00110896 -0.00222217 0.00111321]\n", " [ 0.00111569 -0.00222462 0.00110893]\n", " [ 0.00112012 -0.00223362 0.0011135 ]\n", " [ 0.00111823 -0.0022282 0.00110997]\n", " [ 0.00112029 -0.00223179 0.00111149]\n", " [ 0.00112103 -0.00223379 0.00111275]\n", " [ 0.00112051 -0.00223188 0.00111137]\n", " [ 0.00111947 -0.0022298 0.00111033]\n", " [ 0.00112153 -0.00223408 0.00111255]\n", " [ 0.0011202 -0.00223488 0.00111468]\n", " [ 0.00112156 -0.00223342 0.00111187]\n", " [ 0.00112097 -0.00223548 0.0011145 ]\n", " [ 0.0011212 -0.00223571 0.00111451]\n", " [ 0.00112241 -0.00223594 0.00111353]\n", " [ 0.00112234 -0.00223439 0.00111205]\n", " [ 0.00111672 -0.00223322 0.0011165 ]\n", " [ 0.00111621 -0.00223287 0.00111667]\n", " [ 0.00111912 -0.00223527 0.00111615]\n", " [ 0.0011209 -0.00223649 0.00111558]\n", " [ 0.00112355 -0.00223708 0.00111353]\n", " [ 0.00111236 -0.0022265 0.00111414]\n", " [ 0.00112337 -0.00223777 0.0011144 ]\n", " [ 0.00112003 -0.00223646 0.00111642]\n", " [ 0.00111085 -0.00222242 0.00111157]\n", " [ 0.00111603 -0.00223346 0.00111743]\n", " [ 0.00111307 -0.00222898 0.00111592]\n", " [ 0.00111746 -0.0022349 0.00111744]\n", " [ 0.00111351 -0.00223055 0.00111704]\n", " [ 0.00111074 -0.0022225 0.00111176]\n", " [ 0.00111111 -0.00222222 0.00111111]\n", " [ 0.00110846 -0.00222424 0.00111578]\n", " [ 0.00111205 -0.00222538 0.00111334]\n", " [ 0.00111066 -0.00222256 0.0011119 ]\n", " [ 0.00111053 -0.00222267 0.00111214]\n", " [ 0.00111111 -0.00222222 0.00111111]\n", " [ 0.00110925 -0.00222364 0.0011144 ]\n", " [ 0.0011096 -0.00222337 0.00111376]\n", " [ 0.00111099 -0.00222231 0.00111132]\n", " [ 0.00110797 -0.00222461 0.00111664]\n", " [ 0.00110941 -0.00222351 0.0011141 ]\n", " [ 0.00110854 -0.00222418 0.00111563]\n", " [ 0.00110901 -0.00222382 0.00111481]\n", " [ 0.00110915 -0.00222372 0.00111457]\n", " [ 0.00110876 -0.00222401 0.00111526]\n", " [ 0.00110648 -0.00222435 0.00111787]\n", " [ 0.00110789 -0.00222467 0.00111678]\n", " [ 0.00110597 -0.00222408 0.00111811]\n", " [ 0.00110578 -0.00222396 0.00111818]\n", " [ 0.00110529 -0.00222293 0.00111764]\n", " [ 0.00110754 -0.00222494 0.00111739]\n", " [ 0.00111111 0.00111111 -0.00222222]\n", " [ 0.00111106 0.00111109 -0.00222215]\n", " [ 0.00111102 0.00111111 -0.00222213]\n", " [ 0.00111098 0.00111111 -0.0022221 ]\n", " [ 0.00111101 0.00111113 -0.00222214]\n", " [ 0.00111088 0.00111111 -0.002222 ]\n", " [ 0.00111088 0.00111095 -0.00222182]\n", " [ 0.00111072 0.001111 -0.00222172]\n", " [ 0.00111151 0.0011111 -0.00222261]\n", " [ 0.00111072 0.00111112 -0.00222184]\n", " [ 0.0011122 0.00111036 -0.00222257]\n", " [ 0.00111047 0.00111109 -0.00222156]\n", " [ 0.00111258 0.00110998 -0.00222256]\n", " [ 0.0011125 0.00111018 -0.00222268]\n", " [ 0.00111292 0.00110965 -0.00222256]\n", " [ 0.00111285 0.00110984 -0.00222269]\n", " [ 0.00111169 0.00111125 -0.00222294]\n", " [ 0.00111294 0.00110988 -0.00222282]\n", " [ 0.00111417 0.00110792 -0.00222209]\n", " [ 0.00111301 0.00110995 -0.00222296]\n", " [ 0.00111459 0.00110672 -0.00222131]\n", " [ 0.00111475 0.0011065 -0.00222125]\n", " [ 0.00111442 0.00110642 -0.00222084]\n", " [ 0.00111502 0.00110607 -0.00222109]\n", " [ 0.00111438 0.00110833 -0.00222271]\n", " [ 0.00111548 0.00110563 -0.00222111]\n", " [ 0.0011147 0.00110573 -0.00222043]\n", " [ 0.00111598 0.00110541 -0.00222139]\n", " [ 0.00111614 0.00110537 -0.00222151]\n", " [ 0.00111624 0.00110547 -0.00222171]\n", " [ 0.00111652 0.00110466 -0.00222118]\n", " [ 0.00111416 0.00110556 -0.00221972]\n", " [ 0.00111421 0.00110542 -0.00221963]\n", " [ 0.00111151 0.00110975 -0.00222127]\n", " [ 0.00111356 0.00110578 -0.00221933]\n", " [ 0.00111354 0.0011057 -0.00221924]\n", " [ 0.00111189 0.00110844 -0.00222033]\n", " [ 0.00111473 0.0011045 -0.00221923]\n", " [ 0.00111111 0.00111111 -0.00222222]\n", " [ 0.00111111 0.00111111 -0.00222222]\n", " [ 0.00111348 0.0011053 -0.00221878]\n", " [ 0.00111116 0.00111097 -0.00222213]\n", " [ 0.00111149 0.00110984 -0.00222133]\n", " [ 0.00111108 0.00111109 -0.00222216]\n", " [ 0.0011102 0.00111042 -0.00222062]\n", " [ 0.00111141 0.00111014 -0.00222154]\n", " [ 0.00111154 0.00110966 -0.00222121]\n", " [ 0.00110981 0.00111012 -0.00221992]\n", " [ 0.00111421 0.0011039 -0.00221811]\n", " [ 0.00111015 0.00111038 -0.00222052]\n", " [ 0.00110996 0.00111024 -0.0022202 ]\n", " [ 0.00110995 0.00111022 -0.00222017]\n", " [ 0.00110939 0.0011098 -0.00221919]\n", " [ 0.00111014 0.00111038 -0.00222052]\n", " [ 0.00110913 0.00110961 -0.00221874]\n", " [ 0.00110948 0.00110987 -0.00221935]\n", " [ 0.00110939 0.0011098 -0.0022192 ]\n", " [ 0.00110942 0.00110982 -0.00221924]\n", " [ 0.00110914 0.00110961 -0.00221875]\n", " [ 0.00110806 0.00110988 -0.00221794]\n", " [ 0.00111224 0.00111167 -0.0022239 ]\n", " [ 0.00110768 0.00111106 -0.00221873]\n", " [ 0.00110893 0.00111124 -0.00222017]\n", " [ 0.00110922 0.00110967 -0.00221889]\n", " [ 0.00110931 0.00111131 -0.00222062]\n", " [ 0.0011075 0.00110999 -0.00221749]\n", " [ 0.00110951 0.00111135 -0.00222087]\n", " [ 0.00110777 0.00110962 -0.00221739]\n", " [ 0.00111162 0.00111165 -0.00222326]\n", " [ 0.00110701 0.00111063 -0.00221764]\n", " [ 0.00110697 0.00111041 -0.00221739]\n", " [ 0.00111613 0.00110962 -0.00222575]\n", " [ 0.00112255 0.00109998 -0.00222253]\n", " [ 0.0011154 0.00111043 -0.00222583]\n", " [ 0.00111879 0.00110618 -0.00222496]\n", " [ 0.00112397 0.00109761 -0.00222158]\n", " [ 0.00110674 0.00111103 -0.00221777]\n", " [ 0.00111849 0.00110684 -0.00222533]\n", " [ 0.00111036 0.00111155 -0.00222191]\n", " [ 0.00111878 0.00110659 -0.00222537]\n", " [ 0.00112116 0.00110318 -0.00222434]\n", " [ 0.00112374 0.00109369 -0.00221742]\n", " [ 0.00112493 0.00109689 -0.00222182]\n", " [ 0.00112573 0.00109519 -0.00222092]\n", " [ 0.00112447 0.00109812 -0.00222259]\n", " [ 0.0011262 0.00109446 -0.00222066]\n", " [ 0.00112568 0.00109231 -0.00221799]\n", " [ 0.00112681 0.00109287 -0.00221968]\n", " [ 0.001127 0.00109251 -0.00221951]\n", " [ 0.00112619 0.00109166 -0.00221785]\n", " [ 0.0011269 0.00109405 -0.00222094]\n", " [ 0.00112576 0.00109672 -0.00222247]\n", " [ 0.00112763 0.0010924 -0.00222003]\n", " [ 0.00112169 0.0010934 -0.00221509]\n", " [ 0.00112547 0.00109102 -0.00221649]\n", " [ 0.00112268 0.00109248 -0.00221516]\n", " [ 0.00112286 0.00109225 -0.00221511]\n", " [ 0.00112159 0.00109304 -0.00221464]\n", " [ 0.00111771 0.00109615 -0.00221386]\n", " [ 0.00112716 0.00108963 -0.00221678]]\n" ] } ], "source": [ "if \"__main__\" == __name__:\n", " #数据(3,4)(batch_size, in_features)\n", " X, y = spiral_data(100, 3)#每个类别100个数据,3个类别就是300\n", " layer1 = Linear(2, 5)\n", " layer2 = Linear(5, 3)\n", " layer1.forward(X)\n", " act1 = Relu()\n", " act1.forward(layer1.outputs)\n", " layer2.forward(act1.outputs)\n", " act2 = Softmax()\n", " act2.forward(layer2.outputs)\n", " loss = CategoricalCrossEntropy()\n", " dvalues = loss.forward(act2.outputs, y)\n", " loss.backward(act2.outputs, y)\n", " print(loss.dinputs.shape)\n", " act2.backward(loss.dinputs)\n", " print(act2.dinputs)\n", "# cost = loss.calculate(act2.outputs, y)\n", "# print(cost)\n", "# acc = Accuracy()\n", "# accuracy = acc.forward(act2.outputs, y)\n", "# print(accuracy)" ] }, { "cell_type": "code", "execution_count": null, "id": "c588b453", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }