{ "cells": [ { "cell_type": "markdown", "id": "3942358f", "metadata": {}, "source": [ "#
VGG模型复现" ] }, { "cell_type": "markdown", "id": "02aa5449", "metadata": {}, "source": [ "## 1.模型结构" ] }, { "cell_type": "markdown", "id": "4e698937", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "d782a69a", "metadata": {}, "source": [ "## 2. VGG模型" ] }, { "cell_type": "markdown", "id": "0fccd247", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "80d5d360", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "322d75dc", "metadata": {}, "source": [ "## 3.VGG代码" ] }, { "cell_type": "code", "execution_count": 1, "id": "c5308a19", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torchsummary import summary\n", "from torch.utils.data import DataLoader\n", "from torchvision.datasets import FashionMNIST" ] }, { "cell_type": "code", "execution_count": 2, "id": "031c2fac", "metadata": {}, "outputs": [], "source": [ "cfg = [(3, 3, 64), (3, 64, 64), 'M', (3, 64, 128), (3, 128, 128),\n", " 'M', (3, 128, 256), (3, 256, 256), (3, 256, 256), 'M',\n", " (3, 256, 512), (3, 512, 512), (3, 512, 512), 'M',(3, 512, 512),\n", " (3, 512, 512), (3, 512, 512), 'M'\n", " ]" ] }, { "cell_type": "code", "execution_count": 3, "id": "4f34b85b", "metadata": {}, "outputs": [], "source": [ "class BaseConv(nn.Module):\n", " def __init__(self, kernel_size, in_channels, out_channels):\n", " super().__init__()\n", " self.conv = nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=kernel_size//2)\n", " self.relu = nn.ReLU()\n", " def forward(self, x):\n", " return self.relu(self.conv(x))" ] }, { "cell_type": "code", "execution_count": 4, "id": "bd88cdf6", "metadata": {}, "outputs": [], "source": [ "class VGG(nn.Module):\n", " def __init__(self, num_classes):\n", " super().__init__()\n", " self.num_classes = num_classes\n", " self.cfg = cfg\n", " self.relu = nn.ReLU()\n", " self.fc1 = nn.Linear(7*7*512, 4096)\n", " self.fc2 = nn.Linear(4096, 4096)\n", " self.fc3 = nn.Linear(4096, num_classes)\n", " self.sequential = self.net()\n", " def net(self):\n", " sequential = []\n", " for c in self.cfg:\n", " if isinstance(c, tuple):\n", " sequential.append(BaseConv(c[0], c[1], c[2]))\n", " else:\n", " sequential.append(nn.MaxPool2d(kernel_size=2, stride=2))\n", " return nn.Sequential(*sequential)\n", " \n", " def forward(self, x):\n", " x = self.sequential(x)\n", " x = x.view(x.size(0), -1)\n", " x = self.fc1(x)\n", " x = self.relu(x)\n", " x = self.fc2(x)\n", " x = self.relu(x)\n", " x = self.fc3(x)\n", " x = self.relu(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 5, "id": "947ee66c", "metadata": {}, "outputs": [], "source": [ "vgg = VGG(10)" ] }, { "cell_type": "code", "execution_count": 7, "id": "9fe160fa", "metadata": {}, "outputs": [ { "ename": "RuntimeError", "evalue": "mat1 and mat2 shapes cannot be multiplied (2x41472 and 25088x4096)", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_16692\\316899526.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0msummary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvgg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torchsummary\\torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[1;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[0;32m 70\u001b[0m \u001b[1;31m# make a forward pass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 71\u001b[0m \u001b[1;31m# print(x.shape)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 72\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 73\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# remove these hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_16692\\2003772644.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msequential\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1210\u001b[0m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbw_hook\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup_input_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1211\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1212\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1213\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1214\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", "\u001b[1;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x41472 and 25088x4096)" ] } ], "source": [ "summary(vgg, input_size=(3, 300, 300))" ] }, { "cell_type": "code", "execution_count": null, "id": "f96145cb", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }