You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
262 lines
8.8 KiB
Plaintext
262 lines
8.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3942358f",
|
|
"metadata": {},
|
|
"source": [
|
"# <center>VGG模型复现</center>"
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "02aa5449",
|
|
"metadata": {},
|
|
"source": [
|
"## 1. 模型结构"
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "4e698937",
|
|
"metadata": {},
|
|
"source": [
|
|
"<img src=\"vgg.png\" width=40% align=center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d782a69a",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 2. VGG模型"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "0fccd247",
|
|
"metadata": {},
|
|
"source": [
|
|
"<img src=\"vgg_arch.png\" width=60% align=center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "80d5d360",
|
|
"metadata": {},
|
|
"source": [
|
|
"<img src=\"vgg_detail.png\" width=60% align=center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "322d75dc",
|
|
"metadata": {},
|
|
"source": [
|
"## 3. VGG代码"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "c5308a19",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"from torchsummary import summary\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from torchvision.datasets import FashionMNIST"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "031c2fac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Layer layout of the VGG-16 feature extractor, read in order by VGG.net():\n",
"#   (kernel_size, in_channels, out_channels) -> BaseConv (Conv2d + ReLU)\n",
"#   'M'                                      -> MaxPool2d(kernel_size=2, stride=2)\n",
"cfg = [\n",
"    (3, 3, 64), (3, 64, 64), 'M',                      # block 1\n",
"    (3, 64, 128), (3, 128, 128), 'M',                  # block 2\n",
"    (3, 128, 256), (3, 256, 256), (3, 256, 256), 'M',  # block 3\n",
"    (3, 256, 512), (3, 512, 512), (3, 512, 512), 'M',  # block 4\n",
"    (3, 512, 512), (3, 512, 512), (3, 512, 512), 'M',  # block 5\n",
"]"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "4f34b85b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"class BaseConv(nn.Module):\n",
"    \"\"\"Conv2d followed by ReLU: the basic building block of the VGG feature stack.\n",
"\n",
"    Args:\n",
"        kernel_size: side length of the square convolution kernel.\n",
"        in_channels: number of channels of the incoming feature map.\n",
"        out_channels: number of channels produced by the convolution.\n",
"\n",
"    Padding is ``kernel_size // 2`` with the default stride of 1, so odd kernel\n",
"    sizes (such as the 3x3 kernels in ``cfg``) keep the spatial size unchanged.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, kernel_size, in_channels, out_channels):\n",
"        super(BaseConv, self).__init__()\n",
"        same_padding = kernel_size // 2\n",
"        # Sub-module names `conv` and `relu` determine the state_dict keys.\n",
"        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding)\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        features = self.conv(x)\n",
"        return self.relu(features)"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "bd88cdf6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"class VGG(nn.Module):\n",
"    \"\"\"VGG classifier: BaseConv / max-pool feature stack followed by three linear layers.\n",
"\n",
"    Args:\n",
"        num_classes: number of outputs of the final linear layer.\n",
"        config: optional layer list in the same format as the module-level ``cfg``:\n",
"            ``(kernel_size, in_channels, out_channels)`` tuples become BaseConv\n",
"            layers and any other entry (``'M'``) becomes a 2x2 max-pool with\n",
"            stride 2. Defaults to the global ``cfg`` when omitted.\n",
"\n",
"    ``forward`` returns raw logits of shape ``(batch, num_classes)`` with no\n",
"    final activation; apply softmax / ``nn.CrossEntropyLoss`` outside the model.\n",
"\n",
"    Raises:\n",
"        ValueError: if the layer list contains no convolution tuple.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_classes, config=None):\n",
"        super().__init__()\n",
"        self.num_classes = num_classes\n",
"        # Fall back to the notebook-level `cfg` only when no config is passed.\n",
"        self.cfg = cfg if config is None else config\n",
"        conv_specs = [c for c in self.cfg if isinstance(c, tuple)]\n",
"        if not conv_specs:\n",
"            raise ValueError(\"config must contain at least one (kernel_size, in_channels, out_channels) tuple\")\n",
"        # fc1's width follows the last conv layer (512 for the default cfg).\n",
"        last_channels = conv_specs[-1][2]\n",
"        self.relu = nn.ReLU()\n",
"        # Pool to a fixed 7x7 grid so fc1's input size no longer depends on the\n",
"        # image resolution; for 224x224 inputs the map is already 7x7 (no-op).\n",
"        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))\n",
"        self.fc1 = nn.Linear(7*7*last_channels, 4096)\n",
"        self.fc2 = nn.Linear(4096, 4096)\n",
"        self.fc3 = nn.Linear(4096, num_classes)\n",
"        self.sequential = self.net()\n",
"\n",
"    def net(self):\n",
"        \"\"\"Build the convolutional feature extractor described by ``self.cfg``.\"\"\"\n",
"        sequential = []\n",
"        for c in self.cfg:\n",
"            if isinstance(c, tuple):\n",
"                sequential.append(BaseConv(c[0], c[1], c[2]))\n",
"            else:\n",
"                sequential.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"        return nn.Sequential(*sequential)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.sequential(x)\n",
"        x = self.avgpool(x)\n",
"        # flatten(x, 1) keeps the batch dim and, unlike view, also accepts non-contiguous tensors.\n",
"        x = torch.flatten(x, 1)\n",
"        x = self.relu(self.fc1(x))\n",
"        x = self.relu(self.fc2(x))\n",
"        # No ReLU on the output: clamping logits at 0 would discard negative\n",
"        # scores and zero the gradient for every class whose logit is < 0.\n",
"        return self.fc3(x)"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "947ee66c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Build the default VGG configuration with a 10-way output layer.\n",
"vgg = VGG(num_classes=10)"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "9fe160fa",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"----------------------------------------------------------------\n",
|
|
" Layer (type) Output Shape Param #\n",
|
|
"================================================================\n",
|
|
" Conv2d-1 [-1, 64, 224, 224] 1,792\n",
|
|
" ReLU-2 [-1, 64, 224, 224] 0\n",
|
|
" BaseConv-3 [-1, 64, 224, 224] 0\n",
|
|
" Conv2d-4 [-1, 64, 224, 224] 36,928\n",
|
|
" ReLU-5 [-1, 64, 224, 224] 0\n",
|
|
" BaseConv-6 [-1, 64, 224, 224] 0\n",
|
|
" MaxPool2d-7 [-1, 64, 112, 112] 0\n",
|
|
" Conv2d-8 [-1, 128, 112, 112] 73,856\n",
|
|
" ReLU-9 [-1, 128, 112, 112] 0\n",
|
|
" BaseConv-10 [-1, 128, 112, 112] 0\n",
|
|
" Conv2d-11 [-1, 128, 112, 112] 147,584\n",
|
|
" ReLU-12 [-1, 128, 112, 112] 0\n",
|
|
" BaseConv-13 [-1, 128, 112, 112] 0\n",
|
|
" MaxPool2d-14 [-1, 128, 56, 56] 0\n",
|
|
" Conv2d-15 [-1, 256, 56, 56] 295,168\n",
|
|
" ReLU-16 [-1, 256, 56, 56] 0\n",
|
|
" BaseConv-17 [-1, 256, 56, 56] 0\n",
|
|
" Conv2d-18 [-1, 256, 56, 56] 590,080\n",
|
|
" ReLU-19 [-1, 256, 56, 56] 0\n",
|
|
" BaseConv-20 [-1, 256, 56, 56] 0\n",
|
|
" Conv2d-21 [-1, 256, 56, 56] 590,080\n",
|
|
" ReLU-22 [-1, 256, 56, 56] 0\n",
|
|
" BaseConv-23 [-1, 256, 56, 56] 0\n",
|
|
" MaxPool2d-24 [-1, 256, 28, 28] 0\n",
|
|
" Conv2d-25 [-1, 512, 28, 28] 1,180,160\n",
|
|
" ReLU-26 [-1, 512, 28, 28] 0\n",
|
|
" BaseConv-27 [-1, 512, 28, 28] 0\n",
|
|
" Conv2d-28 [-1, 512, 28, 28] 2,359,808\n",
|
|
" ReLU-29 [-1, 512, 28, 28] 0\n",
|
|
" BaseConv-30 [-1, 512, 28, 28] 0\n",
|
|
" Conv2d-31 [-1, 512, 28, 28] 2,359,808\n",
|
|
" ReLU-32 [-1, 512, 28, 28] 0\n",
|
|
" BaseConv-33 [-1, 512, 28, 28] 0\n",
|
|
" MaxPool2d-34 [-1, 512, 14, 14] 0\n",
|
|
" Conv2d-35 [-1, 512, 14, 14] 2,359,808\n",
|
|
" ReLU-36 [-1, 512, 14, 14] 0\n",
|
|
" BaseConv-37 [-1, 512, 14, 14] 0\n",
|
|
" Conv2d-38 [-1, 512, 14, 14] 2,359,808\n",
|
|
" ReLU-39 [-1, 512, 14, 14] 0\n",
|
|
" BaseConv-40 [-1, 512, 14, 14] 0\n",
|
|
" Conv2d-41 [-1, 512, 14, 14] 2,359,808\n",
|
|
" ReLU-42 [-1, 512, 14, 14] 0\n",
|
|
" BaseConv-43 [-1, 512, 14, 14] 0\n",
|
|
" MaxPool2d-44 [-1, 512, 7, 7] 0\n",
|
|
" Linear-45 [-1, 4096] 102,764,544\n",
|
|
" ReLU-46 [-1, 4096] 0\n",
|
|
" Linear-47 [-1, 4096] 16,781,312\n",
|
|
" ReLU-48 [-1, 4096] 0\n",
|
|
" Linear-49 [-1, 10] 40,970\n",
|
|
" ReLU-50 [-1, 10] 0\n",
|
|
"================================================================\n",
|
|
"Total params: 134,301,514\n",
|
|
"Trainable params: 134,301,514\n",
|
|
"Non-trainable params: 0\n",
|
|
"----------------------------------------------------------------\n",
|
|
"Input size (MB): 0.57\n",
|
|
"Forward/backward pass size (MB): 321.88\n",
|
|
"Params size (MB): 512.32\n",
|
|
"Estimated Total Size (MB): 834.77\n",
|
|
"----------------------------------------------------------------\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
"# `vgg` was never moved to a GPU, so request CPU inputs explicitly; torchsummary's\n",
"# default device=\"cuda\" would create CUDA inputs whenever a GPU is available.\n",
"summary(vgg, input_size=(3, 224, 224), device=\"cpu\")"
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f96145cb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.9"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|