{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1c3e5f7",
   "metadata": {},
   "source": [
    "# Softmax from scratch\n",
    "\n",
    "Softmax is computed four ways: a pure-Python loop, a vectorized NumPy version, a batched (row-wise) version, and a numerically stable version that subtracts the row maximum before exponentiating."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b76a974b",
   "metadata": {},
   "source": [
    "### 1. Pure Python code"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5615c35",
   "metadata": {},
   "source": [
    "Softmax maps a vector of raw scores (logits) to probabilities that are positive and sum to 1. For sample $i$ with $L$ classes and scores $z_{i,1}, \\dots, z_{i,L}$:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d78da6",
   "metadata": {},
   "source": [
    "$$S_{i,j} = \\frac{e^{z_{i,j}}}{\\sum_{l=1}^{L} e^{z_{i,l}}}$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9d932179",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "14179b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = [2.12, 3.14, -2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ace8866d",
   "metadata": {},
   "outputs": [],
   "source": [
    "e = math.e"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e0e8fcc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[8.331137487687691, 23.10386685872218, 0.1353352832366127]\n"
     ]
    }
   ],
   "source": [
    "exp_output = []\n",
    "for output in outputs:\n",
    "    exp_output.append(e**output)\n",
    "print(exp_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cb818c38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.26389128483952834, 0.7318219293727912, 0.0042867857876804265] 0.9999999999999999\n"
     ]
    }
   ],
   "source": [
    "base = sum(exp_output)\n",
    "values = []\n",
    "for value in exp_output:\n",
    "    values.append(value / base)\n",
    "print(values, sum(values))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4cb814b",
   "metadata": {},
   "source": [
    "### 2. NumPy code\n",
    "\n",
    "`np.exp` and `np.sum` replace the explicit Python loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6d2ec344",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "10cc6ec3",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_values = np.exp(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1f43e475",
   "metadata": {},
   "outputs": [],
   "source": [
    "values = exp_values / np.sum(exp_values, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2c9d7f21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.26389128, 0.73182193, 0.00428679]), 1.0)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "values, np.sum(values)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c73bb74",
   "metadata": {},
   "source": [
    "### 3. Batched input (row-wise softmax)\n",
    "\n",
    "Each row is one sample, so the exponentials are normalized along `axis=1`. `keepdims=True` keeps the row sums as an `(n, 1)` column that broadcasts against the `(n, L)` matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3e93a7e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_outputs = [[2.7, 3.05, 4.2],\n",
    "                 [2.5, 4.2, 3],\n",
    "                 [5.0, 1.2, 2.1]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d3b150e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_outputs = np.exp(batch_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0f466cc2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 14.87973172,  21.11534442,  66.68633104],\n",
       "       [ 12.18249396,  66.68633104,  20.08553692],\n",
       "       [148.4131591 ,   3.32011692,   8.16616991]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "exp_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3e3e9f8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "bases = np.sum(exp_outputs, axis=1, keepdims=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "822c0d2d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[102.68140719],\n",
       "       [ 98.95436192],\n",
       "       [159.89944594]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bases"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "89f8c6e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0.14491165, 0.20563941, 0.64944894],\n",
       "        [0.12311225, 0.67390997, 0.20297778],\n",
       "        [0.92816556, 0.02076378, 0.05107066]]),\n",
       " array([1., 1., 1.]))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "soft_outputs = exp_outputs / bases\n",
    "soft_outputs, np.sum(soft_outputs, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2f755cf",
   "metadata": {},
   "source": [
    "### 4. Overflow prevention (numerically stable softmax)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4356909",
   "metadata": {},
   "source": [
    "Subtracting the row maximum before exponentiating leaves the result unchanged, because the common factor $e^{-\\max(u)}$ cancels between numerator and denominator:\n",
    "\n",
    "$$v = u - \\max(u), \\qquad \\mathrm{softmax}(v) = \\mathrm{softmax}(u)$$\n",
    "\n",
    "After the shift every entry of $v$ is $\\le 0$, so each $e^{v_j}$ lies in $[0, 1]$ and cannot overflow (very negative entries may underflow to 0, which is harmless). The largest entry maps to $e^0 = 1$, so the denominator is at least 1."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "958e2993",
   "metadata": {},
   "source": [
    "如果输入值很大，`np.exp` 会上溢（overflow），得到 `inf`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "31d758fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\envs\\stark-lin\\lib\\site-packages\\ipykernel_launcher.py:1: RuntimeWarning: overflow encountered in exp\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "inf"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.exp(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ea69745c",
   "metadata": {},
   "source": [
    "所以在求指数之前，先让每一行减去该行的最大值（按 `axis=1` 取最大值，并用 `keepdims=True` 保持可广播的形状）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6ab89eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "large_outputs = np.array([[2.7, 3.05, 1000],\n",
    "                          [2.5, 4.2, 3],\n",
    "                          [5.0, 1.2, 2.1]])\n",
    "minus_outputs = large_outputs - np.max(large_outputs, axis=1, keepdims=True)\n",
    "stable_exp_outputs = np.exp(minus_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "3a8dc68f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0.        , 0.        , 1.        ],\n",
       "        [0.12311225, 0.67390997, 0.20297778],\n",
       "        [0.92816556, 0.02076378, 0.05107066]]),\n",
       " array([1., 1., 1.]))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minus_softmax = stable_exp_outputs / np.sum(stable_exp_outputs, axis=1, keepdims=True)\n",
    "minus_softmax, np.sum(minus_softmax, axis=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}