{
"cells": [
{
"cell_type": "markdown",
"id": "e9b30084",
"metadata": {},
"source": [
"Differentiating softmax is considerably harder than differentiating the categorical cross-entropy loss."
]
},
{
"cell_type": "markdown",
"id": "e45006fb",
"metadata": {},
"source": [
"$$S_{i,j}=\\frac{e^{z_{i,j}}}{\\sum^L_{l=1}e^{z_{i,l}}} \\rightarrow \\frac{\\partial S_{i,j}}{\\partial z_{i,k}}=\\frac{\\partial\\frac{e^{z_{i,j}}}{\\sum^L_{l=1}e^{z_{i,l}}}}{\\partial z_{i,k}}$$"
]
},
{
"cell_type": "markdown",
"id": "6c3ddae7",
"metadata": {},
"source": [
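"$S_{i,j}$: the $j$-th softmax output for the $i$-th sample.  \n",
"$z$: a vector, the output of the previous layer.  \n",
"$z_{i,j}$: the $j$-th element of the input vector $z$ for the $i$-th sample.  \n",
"$L$: the number of inputs (the length of $z$).  \n",
"$z_{i,k}$: the $k$-th element of the input vector $z$ for the $i$-th sample."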
]
},
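{
"cell_type": "markdown",
"id": "5f1c9a01",
"metadata": {},
"source": [
"The definition of $S_{i,j}$ can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the author's code: `z` is assumed to be a 2-D array of shape `(N, L)` whose row $i$ holds $z_{i,1},\\dots,z_{i,L}$, and the function name `softmax` is our own choice. Subtracting the row maximum before exponentiating multiplies numerator and denominator by the same factor, so $S_{i,j}$ is unchanged, but it keeps `np.exp` from overflowing for large inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a02",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def softmax(z):\n",
"    # z: array of shape (N, L); row i holds the inputs z_{i,1..L}\n",
"    # Subtracting the row maximum does not change the result but prevents overflow in np.exp\n",
"    shifted = z - z.max(axis=1, keepdims=True)\n",
"    exp_z = np.exp(shifted)\n",
"    # Divide each e^{z_{i,j}} by the row sum over l = 1..L\n",
"    return exp_z / exp_z.sum(axis=1, keepdims=True)\n",
"\n",
"\n",
"z = np.array([[1.0, 2.0, 3.0],\n",
"              [1000.0, 1000.0, 1000.0]])\n",
"S = softmax(z)\n",
"print(S)\n",
"print(S.sum(axis=1))  # each row sums to 1"
]
},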
{
"cell_type": "markdown",
"id": "7796014f",
"metadata": {},
"source": [
"Recall the quotient rule for differentiating a fraction:"
]
},
{
"cell_type": "markdown",
"id": "81ecc7e6",
"metadata": {},
"source": [
"$$f(x) = \\frac{g(x)}{h(x)}\\rightarrow f'(x) = \\frac{g'(x)\\cdot h(x)-g(x)\\cdot h'(x)}{[h(x)]^2}$$"
]
},
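{
"cell_type": "markdown",
"id": "5f1c9a03",
"metadata": {},
"source": [
"A quick numerical check of the quotient rule, as a sketch: the functions $g(x)=e^x$ and $h(x)=1+x^2$, the point $x_0=0.7$ and the finite-difference step are arbitrary choices made for illustration. The analytic value from the formula should agree with a central-difference estimate of $f'(x_0)$ to many decimal places."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a04",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def g(x):\n",
"    return np.exp(x)\n",
"\n",
"\n",
"def h(x):\n",
"    return 1.0 + x ** 2\n",
"\n",
"\n",
"def quotient_rule(x):\n",
"    # f'(x) = (g'(x) h(x) - g(x) h'(x)) / h(x)^2, with g'(x) = e^x and h'(x) = 2x\n",
"    return (np.exp(x) * h(x) - g(x) * 2.0 * x) / h(x) ** 2\n",
"\n",
"\n",
"def central_difference(f, x, eps=1e-6):\n",
"    # Numerical derivative (f(x + eps) - f(x - eps)) / (2 eps)\n",
"    return (f(x + eps) - f(x - eps)) / (2.0 * eps)\n",
"\n",
"\n",
"x0 = 0.7\n",
"analytic = quotient_rule(x0)\n",
"numeric = central_difference(lambda x: g(x) / h(x), x0)\n",
"print(analytic, numeric)"
]
},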
{
"cell_type": "markdown",
"id": "24cc1e33",
"metadata": {},
"source": [
"Applying the quotient rule to softmax gives:"
]
},
{
"cell_type": "markdown",
"id": "935665a2",
"metadata": {},
"source": [
"$$\\frac{\\partial S_{i,j}}{\\partial z_{i,k}}=\\frac{\\partial\\frac{e^{z_{i,j}}}{\\sum^L_{l=1}e^{z_{i,l}}}}{\\partial z_{i,k}} =\\frac{ \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,j}}\\cdot\\sum^L_{l=1}e^{z_{i,l}}-e^{z_{i,j}}\\cdot\\frac{\\partial}{\\partial z_{i,k}}\\sum^L_{l=1}e^{z_{i,l}}}{[\\sum^L_{l=1}e^{z_{i,l}}]^2}$$"
]
},
{
"cell_type": "markdown",
"id": "7baa2962",
"metadata": {},
"source": [
"We start with the right-hand term in the numerator, the derivative of the denominator: $\\frac{\\partial}{\\partial z_{i,k}}\\sum^L_{l=1}e^{z_{i,l}}$"
]
},
{
"cell_type": "markdown",
"id": "97811b46",
"metadata": {},
"source": [
"The derivative of the exponential function $e^n$ is:\n",
"$$\\frac{d}{dn}e^n = e^n\\cdot\\frac{d}{dn}n=e^n\\cdot 1=e^n$$"
]
},
{
"cell_type": "markdown",
"id": "01238382",
"metadata": {},
"source": [
"$$\\begin{aligned}\n",
"\\frac{\\partial}{\\partial z_{i,k}}\\sum^L_{l=1}e^{z_{i,l}} &= \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,1}} + \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,2}} + \\cdots + \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,k}} + \\cdots + \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,L-1}} + \\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,L}}\\\\\n",
"&= 0 + 0 + \\cdots + e^{z_{i,k}} + \\cdots + 0 + 0 = e^{z_{i,k}}\n",
"\\end{aligned}$$"
]
},
{
"cell_type": "markdown",
"id": "a0adce34",
"metadata": {},
"source": [
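"When differentiating the sum, every term except $e^{z_{i,k}}$ does not depend on $z_{i,k}$, so it is treated as a constant and its derivative is 0."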
]
},
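{
"cell_type": "markdown",
"id": "5f1c9a05",
"metadata": {},
"source": [
"The result $\\frac{\\partial}{\\partial z_{i,k}}\\sum^L_{l=1}e^{z_{i,l}}=e^{z_{i,k}}$ can be checked numerically for a single sample. A minimal sketch, assuming one row of $L=4$ arbitrary inputs and a central-difference step of $10^{-6}$; `row_sum_exp` is a hypothetical helper name, not part of the original notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a06",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def row_sum_exp(z):\n",
"    # The softmax denominator for one sample: sum over l of e^{z_l}\n",
"    return np.exp(z).sum()\n",
"\n",
"\n",
"z = np.array([0.5, -1.0, 2.0, 0.1])\n",
"eps = 1e-6\n",
"grad = np.empty_like(z)\n",
"for k in range(z.size):\n",
"    # Perturb only z_k; every other term of the sum stays constant\n",
"    step = np.zeros_like(z)\n",
"    step[k] = eps\n",
"    grad[k] = (row_sum_exp(z + step) - row_sum_exp(z - step)) / (2.0 * eps)\n",
"\n",
"print(grad)\n",
"print(np.exp(z))  # matches grad: the k-th partial derivative is e^{z_k}"
]
},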
{
"cell_type": "markdown",
"id": "cd2ebf37",
"metadata": {},
"source": [
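"Next, evaluate the left-hand term in the numerator, $\\frac{\\partial}{\\partial z_{i,k}}e^{z_{i,j}}$. It equals $e^{z_{i,j}}$ when $j=k$, and $0$ when $j\\neq k$ because $e^{z_{i,j}}$ then does not depend on $z_{i,k}$. Substituting both derivatives for the case $j=k$ gives:"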
]
},
{
"cell_type": "markdown",
"id": "492aec4d",
"metadata": {},
"source": [
"$$\\frac{\\partial S_{i,j}}{\\partial z_{i,k}} = \\frac{e^{z_{i,j}}\\cdot \\sum^L_{l=1}e^{z_{i,l}} - e^{z_{i,j}}\\cdot e^{z_{i,k}}}{[\\sum^L_{l=1}e^{z_{i,l}}]^2} \\quad (j=k)$$"
]
},
{
"cell_type": "markdown",
"id": "3dd002e2",
"metadata": {},
"source": [
"For $j=k$, this simplifies to:"
]
},
{
"cell_type": "markdown",
"id": "f0472238",
"metadata": {},
"source": [
"$$\\begin{aligned}\n",
"\\frac{e^{z_{i,j}}\\cdot \\sum^L_{l=1}e^{z_{i,l}} - e^{z_{i,j}}\\cdot e^{z_{i,k}}}{[\\sum^L_{l=1}e^{z_{i,l}}]^2} &= \\frac{e^{z_{i,j}}}{\\sum^L_{l=1}e^{z_{i,l}}}\\cdot\\frac{\\sum^L_{l=1}e^{z_{i,l}} - e^{z_{i,k}}}{\\sum^L_{l=1}e^{z_{i,l}}}\\\\\n",
"&= \\frac{e^{z_{i,j}}}{\\sum^{L}_{l=1}e^{z_{i,l}}}\\cdot \\left(\\frac{\\sum^L_{l=1}e^{z_{i,l}}}{\\sum^L_{l=1}e^{z_{i,l}}}-\\frac{e^{z_{i,k}}}{\\sum^L_{l=1}e^{z_{i,l}}}\\right) = S_{i,j}\\cdot(1-S_{i,k})\n",
"\\end{aligned}$$"
]
},
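{
"cell_type": "markdown",
"id": "5f1c9a07",
"metadata": {},
"source": [
"A numerical sketch of the $j=k$ case for one sample: each diagonal entry $\\partial S_{i,j}/\\partial z_{i,j}$ estimated by central differences should match $S_{i,j}(1-S_{i,j})$. The inputs, the step size and the helper `softmax_row` are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a08",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def softmax_row(z):\n",
"    # Softmax of a single sample; the max shift keeps np.exp from overflowing\n",
"    e = np.exp(z - z.max())\n",
"    return e / e.sum()\n",
"\n",
"\n",
"z = np.array([0.5, -1.0, 2.0, 0.1])\n",
"S = softmax_row(z)\n",
"eps = 1e-6\n",
"\n",
"numeric_diag = np.empty_like(z)\n",
"for j in range(z.size):\n",
"    step = np.zeros_like(z)\n",
"    step[j] = eps\n",
"    # Central difference of the j-th output with respect to the j-th input\n",
"    numeric_diag[j] = (softmax_row(z + step)[j] - softmax_row(z - step)[j]) / (2.0 * eps)\n",
"\n",
"analytic_diag = S * (1.0 - S)  # S_{i,j}(1 - S_{i,k}) with k = j\n",
"print(numeric_diag)\n",
"print(analytic_diag)"
]
},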
{
"cell_type": "markdown",
"id": "b6adb4c7",
"metadata": {},
"source": [
"If $j\\neq k$, the left-hand term in the numerator is $0$:"
]
},
{
"cell_type": "markdown",
"id": "31985b0a",
"metadata": {},
"source": [
"$$\\frac{0\\cdot \\sum^L_{l=1}e^{z_{i,l}} - e^{z_{i,j}}\\cdot e^{z_{i,k}}}{[\\sum^L_{l=1}e^{z_{i,l}}]^2} = \\frac{-e^{z_{i,j}}}{\\sum^L_{l=1}e^{z_{i,l}}}\\cdot \\frac{e^{z_{i,k}}}{\\sum^L_{l=1}e^{z_{i,l}} } = -S_{i,j}\\cdot S_{i,k}$$"
]
},
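{
"cell_type": "markdown",
"id": "5f1c9a09",
"metadata": {},
"source": [
"The $j\\neq k$ case can be checked the same way. In this sketch (the inputs, the pair $j=0$, $k=2$ in zero-based indexing, and the helper `softmax_row` are illustrative assumptions), perturbing $z_{i,k}$ changes $S_{i,j}$ at the rate $-S_{i,j}S_{i,k}$, which is always negative: raising one input lowers every other output."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a10",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def softmax_row(z):\n",
"    # Softmax of a single sample; the max shift keeps np.exp from overflowing\n",
"    e = np.exp(z - z.max())\n",
"    return e / e.sum()\n",
"\n",
"\n",
"z = np.array([0.5, -1.0, 2.0, 0.1])\n",
"S = softmax_row(z)\n",
"eps = 1e-6\n",
"j, k = 0, 2  # any pair with j != k\n",
"\n",
"step = np.zeros_like(z)\n",
"step[k] = eps\n",
"# Central difference of output j with respect to input k\n",
"numeric = (softmax_row(z + step)[j] - softmax_row(z - step)[j]) / (2.0 * eps)\n",
"analytic = -S[j] * S[k]\n",
"print(numeric, analytic)"
]
},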
{
"cell_type": "markdown",
"id": "795a4717",
"metadata": {},
"source": [
"In summary:"
]
},
{
"cell_type": "markdown",
"id": "409d8390",
"metadata": {},
"source": [
"$$\\frac{\\partial S_{i,j}}{\\partial z_{i,k}}=\\begin{cases}\n",
"S_{i,j}\\cdot(1-S_{i,k}) & j=k\\\\\n",
"-S_{i,j}\\cdot S_{i,k} & j\\neq k\n",
"\\end{cases}$$"
]
},
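{
"cell_type": "markdown",
"id": "5f1c9a11",
"metadata": {},
"source": [
"The two cases combine into one matrix per sample: with $\\delta_{jk}$ equal to 1 if $j=k$ and 0 otherwise, $\\frac{\\partial S_{i,j}}{\\partial z_{i,k}} = S_{i,j}(\\delta_{jk} - S_{i,k})$, so the Jacobian of sample $i$ is $\\mathrm{diag}(S_i) - S_i S_i^T$. The sketch below builds these $L\\times L$ Jacobians for a whole batch with `np.einsum` and compares them with finite differences; the function names and the random test inputs are our own assumptions. Forming the full `(N, L, L)` array costs $O(NL^2)$ memory, so the last lines also show that the backward-pass product with an upstream gradient $g_{i,j}=\\partial \\mathcal{L}/\\partial S_{i,j}$ can be computed directly as $S_{i,k}\\,(g_{i,k}-\\sum_j g_{i,j}S_{i,j})$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f1c9a12",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def softmax(z):\n",
"    # Row-wise softmax for z of shape (N, L)\n",
"    e = np.exp(z - z.max(axis=1, keepdims=True))\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"\n",
"def softmax_jacobian(S):\n",
"    # S: softmax outputs, shape (N, L). Returns J of shape (N, L, L) with\n",
"    # J[i, j, k] = dS_{i,j}/dz_{i,k} = S_{i,j} * (delta_{jk} - S_{i,k})\n",
"    L = S.shape[1]\n",
"    diag = np.einsum('ij,jk->ijk', S, np.eye(L))  # S_{i,j} on the diagonal (j = k)\n",
"    outer = np.einsum('ij,ik->ijk', S, S)         # S_{i,j} * S_{i,k}\n",
"    return diag - outer\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"z = rng.normal(size=(3, 4))\n",
"S = softmax(z)\n",
"J = softmax_jacobian(S)\n",
"\n",
"# Finite-difference Jacobian; each row of z only affects the same row of S\n",
"eps = 1e-6\n",
"J_num = np.empty_like(J)\n",
"for k in range(z.shape[1]):\n",
"    step = np.zeros_like(z)\n",
"    step[:, k] = eps\n",
"    J_num[:, :, k] = (softmax(z + step) - softmax(z - step)) / (2.0 * eps)\n",
"print(np.max(np.abs(J - J_num)))\n",
"\n",
"# Backward pass: dLoss/dz_{i,k} = sum_j dLoss/dS_{i,j} * J[i, j, k]\n",
"upstream = rng.normal(size=S.shape)\n",
"grad_z = np.einsum('ij,ijk->ik', upstream, J)\n",
"# Same result without forming J explicitly\n",
"grad_z_fast = S * (upstream - (upstream * S).sum(axis=1, keepdims=True))\n",
"print(np.allclose(grad_z, grad_z_fast))"
]
},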
{
"cell_type": "code",
"execution_count": null,
"id": "b9dfea11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}