You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

209 lines
12 KiB
Plaintext

6 months ago
{
"cells": [
{
"cell_type": "markdown",
"id": "3942358f",
"metadata": {},
"source": [
"# <center>VGG模型复现"
]
},
{
"cell_type": "markdown",
"id": "02aa5449",
"metadata": {},
"source": [
"## 1.模型结构"
]
},
{
"cell_type": "markdown",
"id": "4e698937",
"metadata": {},
"source": [
"<img src=\"vgg.png\" width=40% align=center>"
]
},
{
"cell_type": "markdown",
"id": "d782a69a",
"metadata": {},
"source": [
"## 2. VGG模型"
]
},
{
"cell_type": "markdown",
"id": "0fccd247",
"metadata": {},
"source": [
"<img src=\"vgg_arch.png\" width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "80d5d360",
"metadata": {},
"source": [
"<img src=\"vgg_detail.png\" width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "322d75dc",
"metadata": {},
"source": [
"## 3.VGG代码"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c5308a19",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import FashionMNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "031c2fac",
"metadata": {},
"outputs": [],
"source": [
"cfg = [(3, 3, 64), (3, 64, 64), 'M', (3, 64, 128), (3, 128, 128),\n",
" 'M', (3, 128, 256), (3, 256, 256), (3, 256, 256), 'M',\n",
" (3, 256, 512), (3, 512, 512), (3, 512, 512), 'M',(3, 512, 512),\n",
" (3, 512, 512), (3, 512, 512), 'M'\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4f34b85b",
"metadata": {},
"outputs": [],
"source": [
"class BaseConv(nn.Module):\n",
" def __init__(self, kernel_size, in_channels, out_channels):\n",
" super().__init__()\n",
" self.conv = nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels, padding=kernel_size//2)\n",
" self.relu = nn.ReLU()\n",
" def forward(self, x):\n",
" return self.relu(self.conv(x))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bd88cdf6",
"metadata": {},
"outputs": [],
"source": [
"class VGG(nn.Module):\n",
" def __init__(self, num_classes):\n",
" super().__init__()\n",
" self.num_classes = num_classes\n",
" self.cfg = cfg\n",
" self.relu = nn.ReLU()\n",
" self.fc1 = nn.Linear(7*7*512, 4096)\n",
" self.fc2 = nn.Linear(4096, 4096)\n",
" self.fc3 = nn.Linear(4096, num_classes)\n",
" self.sequential = self.net()\n",
" def net(self):\n",
" sequential = []\n",
" for c in self.cfg:\n",
" if isinstance(c, tuple):\n",
" sequential.append(BaseConv(c[0], c[1], c[2]))\n",
" else:\n",
" sequential.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
" return nn.Sequential(*sequential)\n",
" \n",
" def forward(self, x):\n",
" x = self.sequential(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc1(x)\n",
" x = self.relu(x)\n",
" x = self.fc2(x)\n",
" x = self.relu(x)\n",
" x = self.fc3(x)\n",
" x = self.relu(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "947ee66c",
"metadata": {},
"outputs": [],
"source": [
"vgg = VGG(10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9fe160fa",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "mat1 and mat2 shapes cannot be multiplied (2x41472 and 25088x4096)",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_16692\\316899526.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0msummary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvgg\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m300\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torchsummary\\torchsummary.py\u001b[0m in \u001b[0;36msummary\u001b[1;34m(model, input_size, batch_size, device)\u001b[0m\n\u001b[0;32m 70\u001b[0m \u001b[1;31m# make a forward pass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 71\u001b[0m \u001b[1;31m# print(x.shape)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 72\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 73\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;31m# remove these hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1192\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[0;32m 1193\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1194\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1195\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1196\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_16692\\2003772644.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msequential\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 22\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mview\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msize\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 24\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 25\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfc2\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *input, **kwargs)\u001b[0m\n\u001b[0;32m 1210\u001b[0m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mbw_hook\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msetup_input_hook\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1211\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1212\u001b[1;33m \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1213\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1214\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[1;32min\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0m_global_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32mD:\\envs\\stark-lin\\lib\\site-packages\\torch\\nn\\modules\\linear.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 113\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 114\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 115\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 116\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (2x41472 and 25088x4096)"
]
}
],
"source": [
"summary(vgg, input_size=(3, 300, 300))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f96145cb",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}