You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
339 lines
6.2 KiB
Plaintext
339 lines
6.2 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9aa453ff",
|
|
"metadata": {},
|
|
"source": [
|
|
"$e^x = b \\iff x = \\ln(b)$, where `np.log` computes the natural logarithm $\\ln$."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ee6590f9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import math"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "a0481e20",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"1.6486586255873816\n",
|
|
"5.199999999999999\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"b = 5.2\n",
|
|
"print(np.log(b))\n",
|
|
"print(math.e ** 1.6486586255873816)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "64590073",
|
|
"metadata": {},
|
|
"source": [
|
|
"交叉熵损失函数"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0a45646b",
|
|
"metadata": {},
|
|
"source": [
|
|
"- number of classes: 3\n",
|
|
"- true class label: 0\n",
|
|
"- one-hot target $y$: [1, 0, 0]\n",
|
|
"- softmax prediction $\\hat{y}$: [0.7, 0.1, 0.2]\n",
|
|
"\n",
"$$L = -\\sum_j y_j \\log(\\hat{y}_j) = -(1\\cdot\\log(0.7) + 0\\cdot\\log(0.1) + 0\\cdot\\log(0.2)) = -\\log(0.7) \\approx 0.357$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "bbf3bd0e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"softmax_output = [0.7, 0.1, 0.2]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "44c25853",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"0.35667494393873245\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"target_output = [1, 0, 0]#独热编码\n",
|
|
"loss = -(math.log(softmax_output[0]) * target_output[0] +\n",
|
|
" math.log(softmax_output[1]) * target_output[1] +\n",
|
|
" math.log(softmax_output[2]) * target_output[2])\n",
|
|
"print(loss)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4a9cebb9",
|
|
"metadata": {},
|
|
"source": [
|
|
"简化运算"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "d07da7f4",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.35667494393873245"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"loss = -(math.log(softmax_output[0]) * target_output[0])\n",
|
|
"loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "00ca4e6a",
|
|
"metadata": {},
|
|
"source": [
|
|
"多个数据"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "3b3b7290",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.7 0.5 0.9]\n",
|
|
"[0.7 0.5 0.9]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"softmax_outputs = np.array([[0.7, 0.1, 0.2],\n",
|
|
" [0.1, 0.5, 0.4],\n",
|
|
" [0.02, 0.9, 0.08]])\n",
|
|
"class_targets = [0, 1, 1]\n",
|
|
"print(softmax_outputs[[0, 1, 2],class_targets])\n",
|
|
"print(softmax_outputs[range(len(softmax_outputs)),class_targets])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "e639666d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"neg_log = -np.log(softmax_outputs[range(len(softmax_outputs)),class_targets])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "636b4d4f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([0.35667494, 0.69314718, 0.10536052])"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"neg_log"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "ff5eee5d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"avg_loss = np.mean(neg_log)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "ab993d9f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.38506088005216804"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"avg_loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2d026b79",
|
|
"metadata": {},
|
|
"source": [
|
|
"考虑np.log(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "841620c0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"D:\\envs\\stark-lin\\lib\\site-packages\\ipykernel_launcher.py:1: RuntimeWarning: divide by zero encountered in log\n",
|
|
" \"\"\"Entry point for launching an IPython kernel.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"-inf"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"np.log(0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "45c2d27e",
|
|
"metadata": {},
|
|
"source": [
|
|
"对输出进行裁剪：把预测值限制在 $[10^{-7}, 1-10^{-7}]$ 区间内，避免 `np.log(0)` 产生 `-inf`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "e7642073",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(16.11809565095832, 1.0000000494736474e-07)"
|
|
]
|
|
},
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"-np.log(1e-7), -np.log(1-1e-7)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"id": "7ace302a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"16.11809565095832"
|
|
]
|
|
},
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"y_pred = 0\n",
|
|
"y_pred_clip = np.clip(y_pred, 1e-7, 1-1e-7)\n",
|
|
"-np.log(y_pred_clip)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a3083d8f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|