You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

410 lines
7.7 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "b76a974b",
"metadata": {},
"source": [
"### 1. python code"
]
},
{
"cell_type": "markdown",
"id": "d5615c35",
"metadata": {},
"source": [
"<img src='softmax.png' width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "70d78da6",
"metadata": {},
"source": [
"$$S_{i,j} = \\frac{e^{z_{i,j}}}{\\sum_{l=1}^{L} e^{z_{i,l}}}$$"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9d932179",
"metadata": {},
"outputs": [],
"source": [
"import math"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "14179b41",
"metadata": {},
"outputs": [],
"source": [
"# Example raw scores that the following cells turn into softmax values.\n",
"outputs = [2.12, 3.14, -2]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ace8866d",
"metadata": {},
"outputs": [],
"source": [
"# Euler's number; the next cell uses it as the base in e**output.\n",
"e = math.e"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e0e8fcc5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[8.331137487687691, 23.10386685872218, 0.1353352832366127]\n"
]
}
],
"source": [
"# Raise e to each raw score to get the unnormalised softmax terms.\n",
"exp_output = [e**score for score in outputs]\n",
"print(exp_output)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cb818c38",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.26389128483952834, 0.7318219293727912, 0.0042867857876804265] 0.9999999999999999\n"
]
}
],
"source": [
"# Divide every exponentiated score by their total; the printed sum\n",
"# shows the result adds up to one (up to floating-point rounding).\n",
"base = sum(exp_output)\n",
"values = [exp_value / base for exp_value in exp_output]\n",
"print(values, sum(values))"
]
},
{
"cell_type": "markdown",
"id": "f4cb814b",
"metadata": {},
"source": [
"### 2. numpy code"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6d2ec344",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "10cc6ec3",
"metadata": {},
"outputs": [],
"source": [
"# Keep the array under its own name: rebinding `e` here would overwrite\n",
"# the math.e constant that the pure-Python cells above rely on, so\n",
"# re-running them afterwards would silently compute something else.\n",
"exp_values = np.exp(outputs)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "1f43e475",
"metadata": {},
"outputs": [],
"source": [
"# Normalise by the total of the exponentiated scores.\n",
"values = exp_values / np.sum(exp_values, keepdims=True)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "2c9d7f21",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([0.26389128, 0.73182193, 0.00428679]), 1.0)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"values, np.sum(values)"
]
},
{
"cell_type": "markdown",
"id": "3c73bb74",
"metadata": {},
"source": [
"### 3. batch size"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "3e93a7e2",
"metadata": {},
"outputs": [],
"source": [
"# A batch of scores: the cells below sum along axis=1, so every row is\n",
"# normalised on its own.\n",
"outputs = [[2.7,3.05,4.2],\n",
" [2.5, 4.2, 3],\n",
" [5.0, 1.2, 2.1]]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "d3b150e5",
"metadata": {},
"outputs": [],
"source": [
"# np.exp accepts the nested list and returns a 2-D array (shown in the next cell).\n",
"exp_outputs = np.exp(outputs)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0f466cc2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 14.87973172, 21.11534442, 66.68633104],\n",
" [ 12.18249396, 66.68633104, 20.08553692],\n",
" [148.4131591 , 3.32011692, 8.16616991]])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"exp_outputs"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "3e3e9f8f",
"metadata": {},
"outputs": [],
"source": [
"# Row-wise totals; keepdims=True keeps them as a column so they\n",
"# broadcast against exp_outputs in the division further down.\n",
"bases = exp_outputs.sum(axis=1, keepdims=True)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "822c0d2d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[102.68140719],\n",
" [ 98.95436192],\n",
" [159.89944594]])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bases"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "89f8c6e3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0.14491165, 0.20563941, 0.64944894],\n",
" [0.12311225, 0.67390997, 0.20297778],\n",
" [0.92816556, 0.02076378, 0.05107066]]),\n",
" array([1., 1., 1.]))"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Divide each row by its own total, then show the per-row sums as a check.\n",
"soft_outputs = np.divide(exp_outputs, bases)\n",
"soft_outputs, soft_outputs.sum(axis=1)"
]
},
{
"cell_type": "markdown",
"id": "d2f755cf",
"metadata": {},
"source": [
"### 4. overflow prevention"
]
},
{
"cell_type": "markdown",
"id": "c4356909",
"metadata": {},
"source": [
"$$v = u - \\max(u)$$"
]
},
{
"cell_type": "markdown",
"id": "958e2993",
"metadata": {},
"source": [
"如果输入值很大，直接计算指数会发生溢出："
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "31d758fc",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\envs\\stark-lin\\lib\\site-packages\\ipykernel_launcher.py:1: RuntimeWarning: overflow encountered in exp\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
]
},
{
"data": {
"text/plain": [
"inf"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1000 is too large for np.exp: the result is inf and NumPy emits an\n",
"# overflow RuntimeWarning (both visible in this cell's output).\n",
"np.exp(1000)"
]
},
{
"cell_type": "markdown",
"id": "ea69745c",
"metadata": {},
"source": [
"所以需要先做归一化：在计算指数之前，将每一行减去该行的最大值。"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "6ab89eb9",
"metadata": {},
"outputs": [],
"source": [
"# The first row holds a score of 1000, which np.exp cannot represent directly.\n",
"outputs = np.array([\n",
"    [2.7, 3.05, 1000],\n",
"    [2.5, 4.2, 3],\n",
"    [5.0, 1.2, 2.1],\n",
"])\n",
"# Shift every row by its own maximum so the largest exponent becomes 0.\n",
"minus_outputs = outputs - outputs.max(axis=1, keepdims=True)\n",
"exp_outputs = np.exp(minus_outputs)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "3a8dc68f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([[0. , 0. , 1. ],\n",
" [0.12311225, 0.67390997, 0.20297778],\n",
" [0.92816556, 0.02076378, 0.05107066]]),\n",
" array([1., 1., 1.]))"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Normalise the shifted exponentials row by row, then show the per-row\n",
"# sums as a check.\n",
"minus_softmax = exp_outputs / exp_outputs.sum(axis=1, keepdims=True)\n",
"minus_softmax, minus_softmax.sum(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5081448",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b223fc9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}