You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
410 lines
7.7 KiB
Plaintext
410 lines
7.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "b76a974b",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 1. python code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d5615c35",
|
|
"metadata": {},
|
|
"source": [
|
|
"<img src='softmax.png' width=60% align=center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "70d78da6",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$S_{i,j} = \\frac{e^{z_{i,j}}}{\\sum_{l=1}^{L} e^{z_{i,l}}}$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "9d932179",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import math"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "14179b41",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"outputs = [2.12, 3.14, -2]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ace8866d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"e = math.e"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "e0e8fcc5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[8.331137487687691, 23.10386685872218, 0.1353352832366127]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"exp_output = []\n",
|
|
"for output in outputs:\n",
|
|
" exp_output.append(e**output)\n",
|
|
"print(exp_output)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "cb818c38",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.26389128483952834, 0.7318219293727912, 0.0042867857876804265] 0.9999999999999999\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"base = sum(exp_output)\n",
|
|
"values = []\n",
|
|
"for value in exp_output:\n",
|
|
" values.append(value / base)\n",
|
|
"print(values, sum(values))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "f4cb814b",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 2. numpy code"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "6d2ec344",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "10cc6ec3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"e = np.exp(outputs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "1f43e475",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"values = e / np.sum(e, keepdims=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "2c9d7f21",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([0.26389128, 0.73182193, 0.00428679]), 1.0)"
|
|
]
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"values, np.sum(values)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3c73bb74",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 3. batch size"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "3e93a7e2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"outputs = [[2.7,3.05,4.2],\n",
|
|
" [2.5, 4.2, 3],\n",
|
|
" [5.0, 1.2, 2.1]]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "d3b150e5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"exp_outputs = np.exp(outputs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "0f466cc2",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[ 14.87973172, 21.11534442, 66.68633104],\n",
|
|
" [ 12.18249396, 66.68633104, 20.08553692],\n",
|
|
" [148.4131591 , 3.32011692, 8.16616991]])"
|
|
]
|
|
},
|
|
"execution_count": 26,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"exp_outputs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "3e3e9f8f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"bases = np.sum(exp_outputs, axis=1, keepdims=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 28,
|
|
"id": "822c0d2d",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[102.68140719],\n",
|
|
" [ 98.95436192],\n",
|
|
" [159.89944594]])"
|
|
]
|
|
},
|
|
"execution_count": 28,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"bases"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "89f8c6e3",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([[0.14491165, 0.20563941, 0.64944894],\n",
|
|
" [0.12311225, 0.67390997, 0.20297778],\n",
|
|
" [0.92816556, 0.02076378, 0.05107066]]),\n",
|
|
" array([1., 1., 1.]))"
|
|
]
|
|
},
|
|
"execution_count": 29,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"soft_outputs = exp_outputs / bases\n",
|
|
"soft_outputs, np.sum(soft_outputs, axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d2f755cf",
|
|
"metadata": {},
|
|
"source": [
|
|
"### 4. overflow prevention"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c4356909",
|
|
"metadata": {},
|
|
"source": [
|
|
"$$v = u - \\max(u), \\qquad \\mathrm{softmax}(v) = \\mathrm{softmax}(u)$$"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "958e2993",
|
|
"metadata": {},
|
|
"source": [
|
|
"如果输入中有很大的值，直接求指数会溢出（得到 inf）："
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"id": "31d758fc",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"D:\\envs\\stark-lin\\lib\\site-packages\\ipykernel_launcher.py:1: RuntimeWarning: overflow encountered in exp\n",
|
|
" \"\"\"Entry point for launching an IPython kernel.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"inf"
|
|
]
|
|
},
|
|
"execution_count": 30,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"np.exp(1000)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ea69745c",
|
|
"metadata": {},
|
|
"source": [
|
|
"所以在求指数前先减去每行的最大值：softmax 对整体平移不变，结果不变，但指数的最大值变为 e^0 = 1，不会溢出。"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"id": "6ab89eb9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"outputs = [[2.7,3.05,1000],\n",
|
|
" [2.5, 4.2, 3],\n",
|
|
" [5.0, 1.2, 2.1]]\n",
|
|
"outputs = np.array(outputs)\n",
|
|
"minus_outputs = outputs - np.max(outputs, axis=1, keepdims=True)\n",
|
|
"exp_outputs = np.exp(minus_outputs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 34,
|
|
"id": "3a8dc68f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(array([[0. , 0. , 1. ],\n",
|
|
" [0.12311225, 0.67390997, 0.20297778],\n",
|
|
" [0.92816556, 0.02076378, 0.05107066]]),\n",
|
|
" array([1., 1., 1.]))"
|
|
]
|
|
},
|
|
"execution_count": 34,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"minus_softmax = exp_outputs / np.sum(exp_outputs, axis=1, keepdims=True)\n",
|
|
"minus_softmax, np.sum(minus_softmax, axis=1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a5081448",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1b223fc9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|