You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

366 lines
14 KiB
Plaintext

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"id": "4de4e965",
"metadata": {},
"source": [
"# <center>MobileNet网络复现"
]
},
{
"cell_type": "markdown",
"id": "66f7878c",
"metadata": {},
"source": [
"## 1.MobileNet 网络结构"
]
},
{
"cell_type": "markdown",
"id": "8dffc653",
"metadata": {},
"source": [
"<img src=\"arch.png\" width=40% align=center>"
]
},
{
"cell_type": "markdown",
"id": "6dc0f823",
"metadata": {},
"source": [
"## 2.Depthwise结构"
]
},
{
"cell_type": "markdown",
"id": "62377119",
"metadata": {},
"source": [
"<img src=\"dw.png\" width=40% align=center>"
]
},
{
"cell_type": "markdown",
"id": "44bf2b20",
"metadata": {},
"source": [
"## 3.代码编写"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "682e0150",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"from torchvision.datasets import FashionMNIST\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "23b01841",
"metadata": {},
"outputs": [],
"source": [
"#基础卷积\n",
"class base_conv(nn.Module):\n",
" def __init__(self, kernel_size, in_channels, out_channels, stride):\n",
" super().__init__()\n",
" self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False)\n",
" self.bn = nn.BatchNorm2d(out_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" \n",
" def forward(self, x):\n",
" x = self.conv(x)\n",
" x = self.bn(x)\n",
" x = self.relu(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "93d27a1b",
"metadata": {},
"source": [
"dw结构将groups设为in_channels同时out_channels也设为与in_channels相同。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "c8cb4d70",
"metadata": {},
"outputs": [],
"source": [
"class depth_wise(nn.Module):\n",
" def __init__(self, kernel_size, in_channels, out_channels, stride):\n",
" super().__init__()\n",
" self.conv3x3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=in_channels, stride=stride, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(in_channels)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv1x1 = base_conv(1, in_channels, out_channels, 1)\n",
" self.bn2 = nn.BatchNorm2d(out_channels)\n",
" \n",
" def forward(self, x):\n",
" x = self.conv3x3(x)\n",
" x = self.bn1(x)\n",
" x = self.relu(x)\n",
" x = self.conv1x1(x)\n",
" x = self.bn2(x)\n",
" x = self.relu(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "87933337",
"metadata": {},
"source": [
"每层卷积参数"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "21ef45cf",
"metadata": {},
"outputs": [],
"source": [
"model_cfg = [[3, 32, 2], (3, 64, 1),(3, 128, 2), (3, 128, 1), (3, 256, 2),(3, 256, 1),\n",
" (3, 512, 2), (3, 512, 1, 5), (3, 1024, 2),(3, 1024, 2)]#列表是第一个卷积神经网络,元组是卷积神经网络(卷积核大小,输出通道,步长,重复次数)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "cd3c5501",
"metadata": {},
"outputs": [],
"source": [
"class MobileNetV1(nn.Module):\n",
" def __init__(self, in_channels, num_classes, model_cfg):\n",
" super().__init__()\n",
" self.model_cfg = model_cfg\n",
" self.in_channels = in_channels\n",
" self.avg = nn.AdaptiveAvgPool2d(1)\n",
" self.fc = nn.Linear(1024, num_classes)\n",
" self.softmax = nn.Softmax(dim=1)#在第一个维度上进行softmax第0个维度是batch_size\n",
" self.sequential = self.net()\n",
" def net(self):\n",
" in_channels = self.in_channels\n",
" sequential = []\n",
" for cfg in self.model_cfg:\n",
" if isinstance(cfg, list):#如果是第一个\n",
" sequential.append(base_conv(cfg[0], in_channels, cfg[1], cfg[2]))\n",
" in_channels = cfg[1]\n",
" elif len(cfg) != 5:#如果不是重复块\n",
" sequential.append(depth_wise(cfg[0], in_channels, cfg[1], cfg[2]))\n",
" in_channels = cfg[1]\n",
" else:#重复块\n",
" for _ in range(cfg[-1]):\n",
" sequential.append(depth_wise(cfg[0], in_channels, cfg[1], cfg[2]))\n",
" in_channels = cfg[1]\n",
" return nn.Sequential(*sequential)\n",
" \n",
" def forward(self, x):\n",
" x = self.sequential(x)\n",
" x = self.avg(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.fc(x)\n",
" x = self.softmax(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e9627dc0",
"metadata": {},
"outputs": [],
"source": [
"model = MobileNetV1(3, 10, model_cfg)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fddc919e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([2, 10])\n",
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 32, 112, 112] 864\n",
" BatchNorm2d-2 [-1, 32, 112, 112] 64\n",
" ReLU-3 [-1, 32, 112, 112] 0\n",
" base_conv-4 [-1, 32, 112, 112] 0\n",
" Conv2d-5 [-1, 32, 112, 112] 288\n",
" BatchNorm2d-6 [-1, 32, 112, 112] 64\n",
" ReLU-7 [-1, 32, 112, 112] 0\n",
" Conv2d-8 [-1, 64, 112, 112] 2,048\n",
" BatchNorm2d-9 [-1, 64, 112, 112] 128\n",
" ReLU-10 [-1, 64, 112, 112] 0\n",
" base_conv-11 [-1, 64, 112, 112] 0\n",
" BatchNorm2d-12 [-1, 64, 112, 112] 128\n",
" ReLU-13 [-1, 64, 112, 112] 0\n",
" depth_wise-14 [-1, 64, 112, 112] 0\n",
" Conv2d-15 [-1, 64, 56, 56] 576\n",
" BatchNorm2d-16 [-1, 64, 56, 56] 128\n",
" ReLU-17 [-1, 64, 56, 56] 0\n",
" Conv2d-18 [-1, 128, 56, 56] 8,192\n",
" BatchNorm2d-19 [-1, 128, 56, 56] 256\n",
" ReLU-20 [-1, 128, 56, 56] 0\n",
" base_conv-21 [-1, 128, 56, 56] 0\n",
" BatchNorm2d-22 [-1, 128, 56, 56] 256\n",
" ReLU-23 [-1, 128, 56, 56] 0\n",
" depth_wise-24 [-1, 128, 56, 56] 0\n",
" Conv2d-25 [-1, 128, 56, 56] 1,152\n",
" BatchNorm2d-26 [-1, 128, 56, 56] 256\n",
" ReLU-27 [-1, 128, 56, 56] 0\n",
" Conv2d-28 [-1, 128, 56, 56] 16,384\n",
" BatchNorm2d-29 [-1, 128, 56, 56] 256\n",
" ReLU-30 [-1, 128, 56, 56] 0\n",
" base_conv-31 [-1, 128, 56, 56] 0\n",
" BatchNorm2d-32 [-1, 128, 56, 56] 256\n",
" ReLU-33 [-1, 128, 56, 56] 0\n",
" depth_wise-34 [-1, 128, 56, 56] 0\n",
" Conv2d-35 [-1, 128, 28, 28] 1,152\n",
" BatchNorm2d-36 [-1, 128, 28, 28] 256\n",
" ReLU-37 [-1, 128, 28, 28] 0\n",
" Conv2d-38 [-1, 256, 28, 28] 32,768\n",
" BatchNorm2d-39 [-1, 256, 28, 28] 512\n",
" ReLU-40 [-1, 256, 28, 28] 0\n",
" base_conv-41 [-1, 256, 28, 28] 0\n",
" BatchNorm2d-42 [-1, 256, 28, 28] 512\n",
" ReLU-43 [-1, 256, 28, 28] 0\n",
" depth_wise-44 [-1, 256, 28, 28] 0\n",
" Conv2d-45 [-1, 256, 28, 28] 2,304\n",
" BatchNorm2d-46 [-1, 256, 28, 28] 512\n",
" ReLU-47 [-1, 256, 28, 28] 0\n",
" Conv2d-48 [-1, 256, 28, 28] 65,536\n",
" BatchNorm2d-49 [-1, 256, 28, 28] 512\n",
" ReLU-50 [-1, 256, 28, 28] 0\n",
" base_conv-51 [-1, 256, 28, 28] 0\n",
" BatchNorm2d-52 [-1, 256, 28, 28] 512\n",
" ReLU-53 [-1, 256, 28, 28] 0\n",
" depth_wise-54 [-1, 256, 28, 28] 0\n",
" Conv2d-55 [-1, 256, 14, 14] 2,304\n",
" BatchNorm2d-56 [-1, 256, 14, 14] 512\n",
" ReLU-57 [-1, 256, 14, 14] 0\n",
" Conv2d-58 [-1, 512, 14, 14] 131,072\n",
" BatchNorm2d-59 [-1, 512, 14, 14] 1,024\n",
" ReLU-60 [-1, 512, 14, 14] 0\n",
" base_conv-61 [-1, 512, 14, 14] 0\n",
" BatchNorm2d-62 [-1, 512, 14, 14] 1,024\n",
" ReLU-63 [-1, 512, 14, 14] 0\n",
" depth_wise-64 [-1, 512, 14, 14] 0\n",
" Conv2d-65 [-1, 512, 14, 14] 4,608\n",
" BatchNorm2d-66 [-1, 512, 14, 14] 1,024\n",
" ReLU-67 [-1, 512, 14, 14] 0\n",
" Conv2d-68 [-1, 512, 14, 14] 262,144\n",
" BatchNorm2d-69 [-1, 512, 14, 14] 1,024\n",
" ReLU-70 [-1, 512, 14, 14] 0\n",
" base_conv-71 [-1, 512, 14, 14] 0\n",
" BatchNorm2d-72 [-1, 512, 14, 14] 1,024\n",
" ReLU-73 [-1, 512, 14, 14] 0\n",
" depth_wise-74 [-1, 512, 14, 14] 0\n",
" Conv2d-75 [-1, 512, 7, 7] 4,608\n",
" BatchNorm2d-76 [-1, 512, 7, 7] 1,024\n",
" ReLU-77 [-1, 512, 7, 7] 0\n",
" Conv2d-78 [-1, 1024, 7, 7] 524,288\n",
" BatchNorm2d-79 [-1, 1024, 7, 7] 2,048\n",
" ReLU-80 [-1, 1024, 7, 7] 0\n",
" base_conv-81 [-1, 1024, 7, 7] 0\n",
" BatchNorm2d-82 [-1, 1024, 7, 7] 2,048\n",
" ReLU-83 [-1, 1024, 7, 7] 0\n",
" depth_wise-84 [-1, 1024, 7, 7] 0\n",
" Conv2d-85 [-1, 1024, 4, 4] 9,216\n",
" BatchNorm2d-86 [-1, 1024, 4, 4] 2,048\n",
" ReLU-87 [-1, 1024, 4, 4] 0\n",
" Conv2d-88 [-1, 1024, 4, 4] 1,048,576\n",
" BatchNorm2d-89 [-1, 1024, 4, 4] 2,048\n",
" ReLU-90 [-1, 1024, 4, 4] 0\n",
" base_conv-91 [-1, 1024, 4, 4] 0\n",
" BatchNorm2d-92 [-1, 1024, 4, 4] 2,048\n",
" ReLU-93 [-1, 1024, 4, 4] 0\n",
" depth_wise-94 [-1, 1024, 4, 4] 0\n",
"AdaptiveAvgPool2d-95 [-1, 1024, 1, 1] 0\n",
" Linear-96 [-1, 10] 10,250\n",
" Softmax-97 [-1, 10] 0\n",
"================================================================\n",
"Total params: 2,149,834\n",
"Trainable params: 2,149,834\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 167.97\n",
"Params size (MB): 8.20\n",
"Estimated Total Size (MB): 176.75\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(model, input_size=(3, 224, 224))"
]
},
{
"cell_type": "markdown",
"id": "c47e90e3",
"metadata": {},
"source": [
"总结:感觉论文中给的模型结果表有误,修改位置如下。"
]
},
{
"cell_type": "markdown",
"id": "71074fab",
"metadata": {},
"source": [
"<img src=\"modify.png\" width=40% align=center>"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "720904d6",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}