You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
8.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"id": "3942358f",
"metadata": {},
"source": [
"# <center>VGG模型复现</center>"
]
},
{
"cell_type": "markdown",
"id": "02aa5449",
"metadata": {},
"source": [
"## 1. 模型结构"
]
},
{
"cell_type": "markdown",
"id": "4e698937",
"metadata": {},
"source": [
"<img src=\"vgg.png\" width=40% align=center>"
]
},
{
"cell_type": "markdown",
"id": "d782a69a",
"metadata": {},
"source": [
"## 2. VGG模型"
]
},
{
"cell_type": "markdown",
"id": "0fccd247",
"metadata": {},
"source": [
"<img src=\"vgg_arch.png\" width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "80d5d360",
"metadata": {},
"source": [
"<img src=\"vgg_detail.png\" width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "322d75dc",
"metadata": {},
"source": [
"## 3. VGG代码"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5308a19",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import FashionMNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "031c2fac",
"metadata": {},
"outputs": [],
"source": [
"# VGG-16 layer configuration (configuration D in the VGG paper).\n",
"# A tuple (kernel_size, in_channels, out_channels) is one conv + ReLU unit;\n",
"# 'M' is a 2x2 max-pool with stride 2 that halves height and width.\n",
"cfg = [\n",
"    (3, 3, 64), (3, 64, 64), 'M',\n",
"    (3, 64, 128), (3, 128, 128), 'M',\n",
"    (3, 128, 256), (3, 256, 256), (3, 256, 256), 'M',\n",
"    (3, 256, 512), (3, 512, 512), (3, 512, 512), 'M',\n",
"    (3, 512, 512), (3, 512, 512), (3, 512, 512), 'M',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4f34b85b",
"metadata": {},
"outputs": [],
"source": [
"class BaseConv(nn.Module):\n",
"    \"\"\"One VGG conv unit: a Conv2d followed by a ReLU.\n",
"\n",
"    Padding is ``kernel_size // 2`` with the default stride of 1, so for odd\n",
"    kernel sizes the convolution preserves height and width.\n",
"    \"\"\"\n",
"    def __init__(self, kernel_size, in_channels, out_channels):\n",
"        super().__init__()\n",
"        same_padding = kernel_size // 2\n",
"        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        features = self.conv(x)\n",
"        return self.relu(features)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bd88cdf6",
"metadata": {},
"outputs": [],
"source": [
"class VGG(nn.Module):\n",
"    \"\"\"VGG classifier built from a layer configuration list.\n",
"\n",
"    Args:\n",
"        num_classes: number of outputs of the final Linear layer (fc3).\n",
"        config: layer list in the same format as the module-level ``cfg``:\n",
"            ``(kernel_size, in_channels, out_channels)`` tuples become\n",
"            ``BaseConv`` units and any other entry (``'M'``) becomes a\n",
"            2x2 / stride-2 max-pool. Defaults to the module-level ``cfg``.\n",
"        input_size: side length of the (assumed square) input image, used only\n",
"            to size ``fc1``. The default 224 with the default ``cfg`` gives\n",
"            the original 7*7*512 input features.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``input_size`` is too small to survive all max-pools.\n",
"\n",
"    ``forward`` returns raw, unnormalised logits of shape (N, num_classes),\n",
"    which is what ``nn.CrossEntropyLoss`` expects.\n",
"    \"\"\"\n",
"    def __init__(self, num_classes, config=None, input_size=224):\n",
"        super().__init__()\n",
"        self.num_classes = num_classes\n",
"        self.cfg = cfg if config is None else config\n",
"        self.relu = nn.ReLU()\n",
"        # Each max-pool floors H and W by half, so after k pools the feature map\n",
"        # is input_size // 2**k per side; the channel count is the last conv's.\n",
"        num_pools = sum(1 for c in self.cfg if not isinstance(c, tuple))\n",
"        last_channels = [c[2] for c in self.cfg if isinstance(c, tuple)][-1]\n",
"        spatial = input_size // 2 ** num_pools\n",
"        if spatial < 1:\n",
"            raise ValueError(f'input_size={input_size} is too small for {num_pools} max-pool layers')\n",
"        self.fc1 = nn.Linear(spatial * spatial * last_channels, 4096)\n",
"        self.fc2 = nn.Linear(4096, 4096)\n",
"        self.fc3 = nn.Linear(4096, num_classes)\n",
"        self.sequential = self.net()\n",
"\n",
"    def net(self):\n",
"        \"\"\"Build the convolutional feature extractor from ``self.cfg``.\"\"\"\n",
"        layers = []\n",
"        for c in self.cfg:\n",
"            if isinstance(c, tuple):\n",
"                layers.append(BaseConv(c[0], c[1], c[2]))\n",
"            else:\n",
"                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.sequential(x)\n",
"        x = torch.flatten(x, 1)\n",
"        x = self.relu(self.fc1(x))\n",
"        x = self.relu(self.fc2(x))\n",
"        # No activation after the last layer: a ReLU here would clamp every\n",
"        # logit to >= 0 and block the gradient of negative logits.\n",
"        return self.fc3(x)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "947ee66c",
"metadata": {},
"outputs": [],
"source": [
"# Instantiate with a 10-way classifier head (fc3 has 10 outputs).\n",
"vgg = VGG(num_classes=10)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9fe160fa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 224, 224] 1,792\n",
" ReLU-2 [-1, 64, 224, 224] 0\n",
" BaseConv-3 [-1, 64, 224, 224] 0\n",
" Conv2d-4 [-1, 64, 224, 224] 36,928\n",
" ReLU-5 [-1, 64, 224, 224] 0\n",
" BaseConv-6 [-1, 64, 224, 224] 0\n",
" MaxPool2d-7 [-1, 64, 112, 112] 0\n",
" Conv2d-8 [-1, 128, 112, 112] 73,856\n",
" ReLU-9 [-1, 128, 112, 112] 0\n",
" BaseConv-10 [-1, 128, 112, 112] 0\n",
" Conv2d-11 [-1, 128, 112, 112] 147,584\n",
" ReLU-12 [-1, 128, 112, 112] 0\n",
" BaseConv-13 [-1, 128, 112, 112] 0\n",
" MaxPool2d-14 [-1, 128, 56, 56] 0\n",
" Conv2d-15 [-1, 256, 56, 56] 295,168\n",
" ReLU-16 [-1, 256, 56, 56] 0\n",
" BaseConv-17 [-1, 256, 56, 56] 0\n",
" Conv2d-18 [-1, 256, 56, 56] 590,080\n",
" ReLU-19 [-1, 256, 56, 56] 0\n",
" BaseConv-20 [-1, 256, 56, 56] 0\n",
" Conv2d-21 [-1, 256, 56, 56] 590,080\n",
" ReLU-22 [-1, 256, 56, 56] 0\n",
" BaseConv-23 [-1, 256, 56, 56] 0\n",
" MaxPool2d-24 [-1, 256, 28, 28] 0\n",
" Conv2d-25 [-1, 512, 28, 28] 1,180,160\n",
" ReLU-26 [-1, 512, 28, 28] 0\n",
" BaseConv-27 [-1, 512, 28, 28] 0\n",
" Conv2d-28 [-1, 512, 28, 28] 2,359,808\n",
" ReLU-29 [-1, 512, 28, 28] 0\n",
" BaseConv-30 [-1, 512, 28, 28] 0\n",
" Conv2d-31 [-1, 512, 28, 28] 2,359,808\n",
" ReLU-32 [-1, 512, 28, 28] 0\n",
" BaseConv-33 [-1, 512, 28, 28] 0\n",
" MaxPool2d-34 [-1, 512, 14, 14] 0\n",
" Conv2d-35 [-1, 512, 14, 14] 2,359,808\n",
" ReLU-36 [-1, 512, 14, 14] 0\n",
" BaseConv-37 [-1, 512, 14, 14] 0\n",
" Conv2d-38 [-1, 512, 14, 14] 2,359,808\n",
" ReLU-39 [-1, 512, 14, 14] 0\n",
" BaseConv-40 [-1, 512, 14, 14] 0\n",
" Conv2d-41 [-1, 512, 14, 14] 2,359,808\n",
" ReLU-42 [-1, 512, 14, 14] 0\n",
" BaseConv-43 [-1, 512, 14, 14] 0\n",
" MaxPool2d-44 [-1, 512, 7, 7] 0\n",
" Linear-45 [-1, 4096] 102,764,544\n",
" ReLU-46 [-1, 4096] 0\n",
" Linear-47 [-1, 4096] 16,781,312\n",
" ReLU-48 [-1, 4096] 0\n",
" Linear-49 [-1, 10] 40,970\n",
" ReLU-50 [-1, 10] 0\n",
"================================================================\n",
"Total params: 134,301,514\n",
"Trainable params: 134,301,514\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 321.88\n",
"Params size (MB): 512.32\n",
"Estimated Total Size (MB): 834.77\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# torchsummary's summary() defaults to device='cuda' and then builds its dummy\n",
"# input on the GPU whenever CUDA is available; pass the model's own device so\n",
"# the input and the weights always live on the same device.\n",
"summary(vgg, input_size=(3, 224, 224), device=next(vgg.parameters()).device.type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f96145cb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}