{ "cells": [ { "cell_type": "markdown", "id": "4de4e965", "metadata": {}, "source": [ "#
MobileNet网络复现" ] }, { "cell_type": "markdown", "id": "66f7878c", "metadata": {}, "source": [ "## 1.MobileNet 网络结构" ] }, { "cell_type": "markdown", "id": "8dffc653", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "6dc0f823", "metadata": {}, "source": [ "## 2.Depthwise结构" ] }, { "cell_type": "markdown", "id": "62377119", "metadata": {}, "source": [ "" ] }, { "cell_type": "markdown", "id": "44bf2b20", "metadata": {}, "source": [ "## 3.代码编写" ] }, { "cell_type": "code", "execution_count": 1, "id": "682e0150", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torchsummary import summary\n", "from torchvision.datasets import FashionMNIST\n", "from torch.utils.data import DataLoader" ] }, { "cell_type": "code", "execution_count": 2, "id": "23b01841", "metadata": {}, "outputs": [], "source": [ "#基础卷积\n", "class base_conv(nn.Module):\n", " def __init__(self, kernel_size, in_channels, out_channels, stride):\n", " super().__init__()\n", " self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False)\n", " self.bn = nn.BatchNorm2d(out_channels)\n", " self.relu = nn.ReLU(inplace=True)\n", " \n", " def forward(self, x):\n", " x = self.conv(x)\n", " x = self.bn(x)\n", " x = self.relu(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "93d27a1b", "metadata": {}, "source": [ "dw结构将groups设为in_channels,同时out_channels也设为与in_channels相同。" ] }, { "cell_type": "code", "execution_count": 3, "id": "c8cb4d70", "metadata": {}, "outputs": [], "source": [ "class depth_wise(nn.Module):\n", " def __init__(self, kernel_size, in_channels, out_channels, stride):\n", " super().__init__()\n", " self.conv3x3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=in_channels, stride=stride, bias=False)\n", " self.bn1 = nn.BatchNorm2d(in_channels)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv1x1 = base_conv(1, in_channels, out_channels, 1)\n", " self.bn2 = nn.BatchNorm2d(out_channels)\n", " \n", " def forward(self, x):\n", " x = self.conv3x3(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", " x = self.conv1x1(x)\n", " x = self.bn2(x)\n", " x = self.relu(x)\n", " return x" ] }, { "cell_type": "markdown", "id": "87933337", "metadata": {}, "source": [ "每层卷积参数" ] }, { "cell_type": "code", "execution_count": 4, "id": "21ef45cf", "metadata": {}, "outputs": [], "source": [ "model_cfg = [[3, 32, 2], (3, 64, 1),(3, 128, 2), (3, 128, 1), (3, 256, 2),(3, 256, 1),\n", " (3, 512, 2), (3, 512, 1, 5), (3, 1024, 2),(3, 1024, 2)]#列表是第一个卷积神经网络,元组是卷积神经网络(卷积核大小,输出通道,步长,重复次数)" ] }, { "cell_type": "code", "execution_count": 11, "id": "cd3c5501", "metadata": {}, "outputs": [], "source": [ "class MobileNetV1(nn.Module):\n", " def __init__(self, in_channels, num_classes, model_cfg):\n", " super().__init__()\n", " self.model_cfg = model_cfg\n", " self.in_channels = in_channels\n", " self.avg = nn.AdaptiveAvgPool2d(1)\n", " self.fc = nn.Linear(1024, num_classes)\n", " self.softmax = nn.Softmax(dim=1)#在第一个维度上进行softmax,第0个维度是batch_size\n", " self.sequential = self.net()\n", " def net(self):\n", " in_channels = self.in_channels\n", " sequential = []\n", " for cfg in self.model_cfg:\n", " if isinstance(cfg, list):#如果是第一个\n", " sequential.append(base_conv(cfg[0], in_channels, cfg[1], cfg[2]))\n", " in_channels = cfg[1]\n", " elif len(cfg) != 5:#如果不是重复块\n", " sequential.append(depth_wise(cfg[0], in_channels, cfg[1], cfg[2]))\n", " in_channels = cfg[1]\n", " else:#重复块\n", " for _ in range(cfg[-1]):\n", " sequential.append(depth_wise(cfg[0], in_channels, cfg[1], cfg[2]))\n", " in_channels = cfg[1]\n", " return nn.Sequential(*sequential)\n", " \n", " def forward(self, x):\n", " x = self.sequential(x)\n", " x = self.avg(x)\n", " x = x.view(x.size(0), -1)\n", " x = self.fc(x)\n", " x = self.softmax(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 12, "id": "e9627dc0", "metadata": {}, "outputs": [], "source": [ "model = MobileNetV1(3, 10, model_cfg)" ] }, { "cell_type": "code", "execution_count": 13, "id": "fddc919e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([2, 10])\n", "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 32, 112, 112] 864\n", " BatchNorm2d-2 [-1, 32, 112, 112] 64\n", " ReLU-3 [-1, 32, 112, 112] 0\n", " base_conv-4 [-1, 32, 112, 112] 0\n", " Conv2d-5 [-1, 32, 112, 112] 288\n", " BatchNorm2d-6 [-1, 32, 112, 112] 64\n", " ReLU-7 [-1, 32, 112, 112] 0\n", " Conv2d-8 [-1, 64, 112, 112] 2,048\n", " BatchNorm2d-9 [-1, 64, 112, 112] 128\n", " ReLU-10 [-1, 64, 112, 112] 0\n", " base_conv-11 [-1, 64, 112, 112] 0\n", " BatchNorm2d-12 [-1, 64, 112, 112] 128\n", " ReLU-13 [-1, 64, 112, 112] 0\n", " depth_wise-14 [-1, 64, 112, 112] 0\n", " Conv2d-15 [-1, 64, 56, 56] 576\n", " BatchNorm2d-16 [-1, 64, 56, 56] 128\n", " ReLU-17 [-1, 64, 56, 56] 0\n", " Conv2d-18 [-1, 128, 56, 56] 8,192\n", " BatchNorm2d-19 [-1, 128, 56, 56] 256\n", " ReLU-20 [-1, 128, 56, 56] 0\n", " base_conv-21 [-1, 128, 56, 56] 0\n", " BatchNorm2d-22 [-1, 128, 56, 56] 256\n", " ReLU-23 [-1, 128, 56, 56] 0\n", " depth_wise-24 [-1, 128, 56, 56] 0\n", " Conv2d-25 [-1, 128, 56, 56] 1,152\n", " BatchNorm2d-26 [-1, 128, 56, 56] 256\n", " ReLU-27 [-1, 128, 56, 56] 0\n", " Conv2d-28 [-1, 128, 56, 56] 16,384\n", " BatchNorm2d-29 [-1, 128, 56, 56] 256\n", " ReLU-30 [-1, 128, 56, 56] 0\n", " base_conv-31 [-1, 128, 56, 56] 0\n", " BatchNorm2d-32 [-1, 128, 56, 56] 256\n", " ReLU-33 [-1, 128, 56, 56] 0\n", " depth_wise-34 [-1, 128, 56, 56] 0\n", " Conv2d-35 [-1, 128, 28, 28] 1,152\n", " BatchNorm2d-36 [-1, 128, 28, 28] 256\n", " ReLU-37 [-1, 128, 28, 28] 0\n", " Conv2d-38 [-1, 256, 28, 28] 32,768\n", " BatchNorm2d-39 [-1, 256, 28, 28] 512\n", " ReLU-40 [-1, 256, 28, 28] 0\n", " base_conv-41 [-1, 256, 28, 28] 0\n", " BatchNorm2d-42 [-1, 256, 28, 28] 512\n", " ReLU-43 [-1, 256, 28, 28] 0\n", " depth_wise-44 [-1, 256, 28, 28] 0\n", " Conv2d-45 [-1, 256, 28, 28] 2,304\n", " BatchNorm2d-46 [-1, 256, 28, 28] 512\n", " ReLU-47 [-1, 256, 28, 28] 0\n", " Conv2d-48 [-1, 256, 28, 28] 65,536\n", " BatchNorm2d-49 [-1, 256, 28, 28] 512\n", " ReLU-50 [-1, 256, 28, 28] 0\n", " base_conv-51 [-1, 256, 28, 28] 0\n", " BatchNorm2d-52 [-1, 256, 28, 28] 512\n", " ReLU-53 [-1, 256, 28, 28] 0\n", " depth_wise-54 [-1, 256, 28, 28] 0\n", " Conv2d-55 [-1, 256, 14, 14] 2,304\n", " BatchNorm2d-56 [-1, 256, 14, 14] 512\n", " ReLU-57 [-1, 256, 14, 14] 0\n", " Conv2d-58 [-1, 512, 14, 14] 131,072\n", " BatchNorm2d-59 [-1, 512, 14, 14] 1,024\n", " ReLU-60 [-1, 512, 14, 14] 0\n", " base_conv-61 [-1, 512, 14, 14] 0\n", " BatchNorm2d-62 [-1, 512, 14, 14] 1,024\n", " ReLU-63 [-1, 512, 14, 14] 0\n", " depth_wise-64 [-1, 512, 14, 14] 0\n", " Conv2d-65 [-1, 512, 14, 14] 4,608\n", " BatchNorm2d-66 [-1, 512, 14, 14] 1,024\n", " ReLU-67 [-1, 512, 14, 14] 0\n", " Conv2d-68 [-1, 512, 14, 14] 262,144\n", " BatchNorm2d-69 [-1, 512, 14, 14] 1,024\n", " ReLU-70 [-1, 512, 14, 14] 0\n", " base_conv-71 [-1, 512, 14, 14] 0\n", " BatchNorm2d-72 [-1, 512, 14, 14] 1,024\n", " ReLU-73 [-1, 512, 14, 14] 0\n", " depth_wise-74 [-1, 512, 14, 14] 0\n", " Conv2d-75 [-1, 512, 7, 7] 4,608\n", " BatchNorm2d-76 [-1, 512, 7, 7] 1,024\n", " ReLU-77 [-1, 512, 7, 7] 0\n", " Conv2d-78 [-1, 1024, 7, 7] 524,288\n", " BatchNorm2d-79 [-1, 1024, 7, 7] 2,048\n", " ReLU-80 [-1, 1024, 7, 7] 0\n", " base_conv-81 [-1, 1024, 7, 7] 0\n", " BatchNorm2d-82 [-1, 1024, 7, 7] 2,048\n", " ReLU-83 [-1, 1024, 7, 7] 0\n", " depth_wise-84 [-1, 1024, 7, 7] 0\n", " Conv2d-85 [-1, 1024, 4, 4] 9,216\n", " BatchNorm2d-86 [-1, 1024, 4, 4] 2,048\n", " ReLU-87 [-1, 1024, 4, 4] 0\n", " Conv2d-88 [-1, 1024, 4, 4] 1,048,576\n", " BatchNorm2d-89 [-1, 1024, 4, 4] 2,048\n", " ReLU-90 [-1, 1024, 4, 4] 0\n", " base_conv-91 [-1, 1024, 4, 4] 0\n", " BatchNorm2d-92 [-1, 1024, 4, 4] 2,048\n", " ReLU-93 [-1, 1024, 4, 4] 0\n", " depth_wise-94 [-1, 1024, 4, 4] 0\n", "AdaptiveAvgPool2d-95 [-1, 1024, 1, 1] 0\n", " Linear-96 [-1, 10] 10,250\n", " Softmax-97 [-1, 10] 0\n", "================================================================\n", "Total params: 2,149,834\n", "Trainable params: 2,149,834\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.57\n", "Forward/backward pass size (MB): 167.97\n", "Params size (MB): 8.20\n", "Estimated Total Size (MB): 176.75\n", "----------------------------------------------------------------\n" ] } ], "source": [ "summary(model, input_size=(3, 224, 224))" ] }, { "cell_type": "markdown", "id": "c47e90e3", "metadata": {}, "source": [ "总结:感觉论文中给的模型结果表有误,修改位置如下。" ] }, { "cell_type": "markdown", "id": "71074fab", "metadata": {}, "source": [ "" ] }, { "cell_type": "code", "execution_count": null, "id": "720904d6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 5 }