You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

683 lines
26 KiB
Plaintext

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

{
"cells": [
{
"cell_type": "markdown",
"id": "06b0d2c9",
"metadata": {},
"source": [
"# <center> SSD算法复现"
]
},
{
"cell_type": "markdown",
"id": "30e46dc6",
"metadata": {},
"source": [
"在上一课程vgg16中，我们复现了论文中的算法。在SSD算法中还不能直接拿vgg16作为特征提取器，我们需要修改一下模型。"
]
},
{
"cell_type": "markdown",
"id": "8d1c9081",
"metadata": {},
"source": [
"## 1.算法结构"
]
},
{
"cell_type": "markdown",
"id": "03d9cb10",
"metadata": {},
"source": [
"<img src='ssd.png' width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "9fe1a1cb",
"metadata": {},
"source": [
"<img src='vgg_detail.png' width=70% align=center>"
]
},
{
"cell_type": "markdown",
"id": "8df5108a",
"metadata": {},
"source": [
"1.输入不再是224，在SSD算法中可以选择300或者512，分别称为SSD300、SSD512。 \n",
"2.第3个最大池化层使用ceil模式而不是floor模式，其他参数不变，具体用法请参考nn.MaxPool2d。conv3_3卷积后的特征图尺寸是$75\\times75$，使用ceil可以保证池化后的特征图尺寸是偶数$38\\times38$，否则是奇数$37\\times37$，这样做的目的是方便后续处理。 \n",
"3.把第5个最大池化层的核大小改成3，步长改成1。 \n",
"4.去掉第8层全连接层（fc8），把全连接层6、7（fc6、fc7）改成卷积层。 \n",
"5.我们只需要关注conv4_3和conv7的输出就可以了。"
]
},
{
"cell_type": "markdown",
"id": "5fb20620",
"metadata": {},
"source": [
"<img src=\"mod_vgg16.png\" width=80% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9bef844f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import FashionMNIST\n",
"from torchvision import models"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1bda593e",
"metadata": {},
"outputs": [],
"source": [
"# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ebfca277",
"metadata": {},
"outputs": [],
"source": [
"# VGG16 backbone configuration, consumed in order by VGG16.net1 (first 13 entries) and VGG16.net2.\n",
"#   tuple -> conv layer: (kernel_size, in_channels, out_channels[, padding, dilation])\n",
"#   'M'   -> 2x2 max-pool, stride 2, floor mode\n",
"#   'CM'  -> 2x2 max-pool, stride 2, ceil mode (75x75 -> 38x38 for a 300x300 input)\n",
"#   '3M1' -> 3x3 max-pool, stride 1 (pool5)\n",
"vgg_cfg = [\n",
"    (3, 3, 64), (3, 64, 64), 'M',                             # conv1_x + pool1\n",
"    (3, 64, 128), (3, 128, 128), 'M',                         # conv2_x + pool2\n",
"    (3, 128, 256), (3, 256, 256), (3, 256, 256), 'CM',        # conv3_x + pool3\n",
"    (3, 256, 512), (3, 512, 512), (3, 512, 512),              # conv4_x, the last one is conv4_3\n",
"    'M', (3, 512, 512), (3, 512, 512), (3, 512, 512), '3M1',  # pool4 + conv5_x + pool5\n",
"    (3, 512, 1024, 6, 6), (1, 1024, 1024),                    # conv6 (dilated), conv7\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e103335f",
"metadata": {},
"outputs": [],
"source": [
"class BaseConv(nn.Module):\n",
"    \"\"\"nn.Conv2d optionally followed by ReLU; the building block of every network in this notebook.\"\"\"\n",
"\n",
"    def __init__(self, kernel_size, in_channels, out_channels, padding=1, dilation=1, stride=1, act=True):\n",
"        super(BaseConv, self).__init__()\n",
"        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,\n",
"                              kernel_size=kernel_size, stride=stride,\n",
"                              padding=padding, dilation=dilation)\n",
"        self.relu = nn.ReLU()\n",
"        self.act = act  # False -> return the raw conv output (used by the prediction heads)\n",
"\n",
"    def forward(self, x):\n",
"        out = self.conv(x)\n",
"        return self.relu(out) if self.act else out"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "125c171c",
"metadata": {},
"outputs": [],
"source": [
"class VGG16(nn.Module):\n",
"    \"\"\"Modified VGG16 backbone for SSD300.\n",
"\n",
"    forward(x) returns two feature maps:\n",
"      conv4_3 -- output of the 3rd conv of block 4 (512 channels)\n",
"      conv7   -- output of the 1x1 conv that replaces fc7 (1024 channels)\n",
"    For a 300x300 input they are 38x38 and 19x19.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(VGG16, self).__init__()\n",
"        self.vgg_cfg = vgg_cfg\n",
"        self.seq1 = self.net1()  # image -> conv4_3\n",
"        self.seq2 = self.net2()  # conv4_3 -> conv7 (fc7 replaced by a conv)\n",
"\n",
"    def net1(self):\n",
"        \"\"\"Layers from the input image up to and including conv4_3 (vgg_cfg[:13]).\"\"\"\n",
"        sequential = []\n",
"        for c in self.vgg_cfg[:13]:\n",
"            if isinstance(c, tuple):\n",
"                # 'same' padding for odd kernels: 3x3 -> 1, 1x1 -> 0\n",
"                sequential.append(BaseConv(c[0], c[1], c[2], padding=c[0] // 2, act=True))\n",
"            elif c == 'M':\n",
"                sequential.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"            else:  # 'CM': ceil mode turns 75x75 into 38x38 instead of 37x37\n",
"                sequential.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))\n",
"        return nn.Sequential(*sequential)\n",
"\n",
"    def net2(self):\n",
"        \"\"\"pool4, conv5_x, pool5, conv6 and conv7 (vgg_cfg[13:]).\"\"\"\n",
"        sequential = []\n",
"        for c in self.vgg_cfg[13:]:\n",
"            if isinstance(c, tuple):\n",
"                if len(c) == 3:\n",
"                    # 'same' padding; in particular the 1x1 conv7 must not be zero-padded,\n",
"                    # otherwise it grows the map and adds a border of bias-only activations\n",
"                    sequential.append(BaseConv(c[0], c[1], c[2], padding=c[0] // 2, act=True))\n",
"                else:\n",
"                    # (kernel, in, out, padding, dilation): the dilated conv6\n",
"                    sequential.append(BaseConv(c[0], c[1], c[2], c[3], c[4], act=True))\n",
"            elif c == 'M':\n",
"                sequential.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"            else:  # '3M1' (pool5): padding=1 keeps 19x19 so conv6/conv7 stay aligned with the grid\n",
"                sequential.append(nn.MaxPool2d(kernel_size=3, stride=1, padding=1))\n",
"        return nn.Sequential(*sequential)\n",
"\n",
"    def forward(self, x):\n",
"        conv4_3 = self.seq1(x)  # 3rd conv of the 4th group in the figure\n",
"        conv7 = self.seq2(conv4_3)  # conv7 that replaces fc7 in the figure\n",
"        return conv4_3, conv7\n",
"\n",
"    def load_pretrained_params(self):\n",
"        \"\"\"Initialise the backbone from torchvision's ImageNet-pretrained VGG16.\n",
"\n",
"        conv1_1 ... conv5_3 are copied one-to-one; fc6/fc7 are reshaped into conv kernels and\n",
"        decimated to the smaller conv6/conv7 used here (torchvision downloads the weights on first use).\n",
"        \"\"\"\n",
"        current_dict = self.state_dict()  # parameter names and tensors of this model\n",
"        current_names = list(current_dict.keys())\n",
"        # `weights=` replaces the deprecated `pretrained=` flag and pins the weight version\n",
"        pretrained_dict = models.vgg16(weights=models.VGG16_Weights.DEFAULT).state_dict()\n",
"\n",
"        # Transfer learning: copy only the 13 conv layers of `features` (26 tensors).\n",
"        # The classifier (fc6/fc7/fc8) is converted separately below; zip() stops before conv6/conv7.\n",
"        feature_names = [name for name in pretrained_dict if name.startswith('features.')]\n",
"        for current_name, pretrained_name in zip(current_names, feature_names):\n",
"            current_dict[current_name] = pretrained_dict[pretrained_name]\n",
"\n",
"        # fc6: (4096, 25088) -> conv kernel (4096, 512, 7, 7)\n",
"        fc6_weights = pretrained_dict['classifier.0.weight'].view(4096, 512, 7, 7)\n",
"        fc6_bias = pretrained_dict['classifier.0.bias']  # (4096,)\n",
"        # fc7: (4096, 4096) -> conv kernel (4096, 4096, 1, 1)\n",
"        fc7_weights = pretrained_dict['classifier.3.weight'].view(4096, 4096, 1, 1)\n",
"        fc7_bias = pretrained_dict['classifier.3.bias']  # (4096,)\n",
"\n",
"        # Subsample (decimate) down to conv6/conv7.\n",
"        # conv6: every 4th output channel, every 3rd kernel tap (0, 3, 6) -> (1024, 512, 3, 3)\n",
"        conv6_weights = self.decimate(fc6_weights, m=[4, None, 3, 3])\n",
"        conv6_bias = self.decimate(fc6_bias, m=[4])  # (4096,) -> (1024,)\n",
"        # conv7: every 4th output and input channel -> (1024, 1024, 1, 1)\n",
"        conv7_weights = self.decimate(fc7_weights, m=[4, 4, None, None])\n",
"        conv7_bias = self.decimate(fc7_bias, m=[4])  # (4096,) -> (1024,)\n",
"\n",
"        # Copy the decimated tensors into this model's conv6 / conv7\n",
"        current_dict['seq2.5.conv.weight'] = conv6_weights\n",
"        current_dict['seq2.5.conv.bias'] = conv6_bias\n",
"        current_dict['seq2.6.conv.weight'] = conv7_weights\n",
"        current_dict['seq2.6.conv.bias'] = conv7_bias\n",
"        self.load_state_dict(current_dict)\n",
"        print(\"pretrained params load finished!\")\n",
"\n",
"    def decimate(self, tensor, m):\n",
"        \"\"\"\n",
"        Decimate a tensor by a factor 'm', i.e. downsample by keeping every 'm'th value.\n",
"\n",
"        This is used when we convert FC layers to equivalent Convolutional layers, BUT of a smaller size.\n",
"\n",
"        :param tensor: tensor to be decimated\n",
"        :param m: list of decimation factors for each dimension of the tensor; None if not to be decimated along a dimension\n",
"        :return: decimated tensor\n",
"        \"\"\"\n",
"        assert tensor.dim() == len(m)  # exactly one factor (or None) per dimension\n",
"        for d in range(tensor.dim()):\n",
"            if m[d] is not None:\n",
"                # indices 0, m, 2m, ... along dimension d, created on the tensor's own device\n",
"                index = torch.arange(start=0, end=tensor.size(d), step=m[d], device=tensor.device).long()\n",
"                tensor = tensor.index_select(dim=d, index=index)\n",
"        return tensor"
]
},
{
"cell_type": "markdown",
"id": "ea144c5a",
"metadata": {},
"source": [
"vgg16每一层参数的名字如下，一共32个。每一层包含权重和偏置两个参数，所以对应16层。名字中的序号表示该层在features（或classifier）中的下标，比如0之后是2，因为第0层卷积后面还有一个激活函数ReLU()，其他层的下标都是这样计算的。"
]
},
{
"cell_type": "markdown",
"id": "112b7c0c",
"metadata": {},
"source": [
"官网的模型名字:['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']"
]
},
{
"cell_type": "markdown",
"id": "11c2d3f8",
"metadata": {},
"source": [
"conv6, conv7对应的层为:'seq2.5.conv.weight', 'seq2.5.conv.bias', 'seq2.6.conv.weight', 'seq2.6.conv.bias'"
]
},
{
"cell_type": "markdown",
"id": "07564988",
"metadata": {},
"source": [
"自己的模型名字:['seq1.0.conv.weight', 'seq1.0.conv.bias', 'seq1.1.conv.weight', 'seq1.1.conv.bias', 'seq1.3.conv.weight', 'seq1.3.conv.bias', 'seq1.4.conv.weight', 'seq1.4.conv.bias', 'seq1.6.conv.weight', 'seq1.6.conv.bias', 'seq1.7.conv.weight', 'seq1.7.conv.bias', 'seq1.8.conv.weight', 'seq1.8.conv.bias', 'seq1.10.conv.weight', 'seq1.10.conv.bias', 'seq1.11.conv.weight', 'seq1.11.conv.bias', 'seq1.12.conv.weight', 'seq1.12.conv.bias', 'seq2.1.conv.weight', 'seq2.1.conv.bias', 'seq2.2.conv.weight', 'seq2.2.conv.bias', 'seq2.3.conv.weight', 'seq2.3.conv.bias', 'seq2.5.conv.weight', 'seq2.5.conv.bias', 'seq2.6.conv.weight', 'seq2.6.conv.bias']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "883d01b4",
"metadata": {},
"outputs": [],
"source": [
"cls = VGG16()  # backbone under test"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f225db7b",
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn((1, 3, 300, 300))  # one dummy SSD300 image, (N, C, H, W)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "12fafd7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 512, 38, 38]), torch.Size([1, 1024, 19, 19]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Run the backbone once and show both shapes: conv4_3 should be 38x38, conv7 19x19.\n",
"tuple(feature.shape for feature in cls(x))"
]
},
{
"cell_type": "markdown",
"id": "b56f3146",
"metadata": {},
"source": [
"除了VGG16作为基础模型以外，还会在其最后一层之后增加辅助卷积网络，其目的是获取conv8_2、conv9_2、conv10_2、conv11_2的特征图。"
]
},
{
"cell_type": "markdown",
"id": "0e95f3d0",
"metadata": {},
"source": [
"<img src=\"auxiliary_conv.png\" width=80% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9cf3dd8e",
"metadata": {},
"outputs": [],
"source": [
"# Auxiliary convs appended after conv7, consumed in pairs (1x1 reduce, then 3x3);\n",
"# each tuple is (kernel_size, in_channels, out_channels, padding, stride).\n",
"auc_cfg = [\n",
"    (1, 1024, 256, 0, 1), (3, 256, 512, 1, 2),  # conv8_1, conv8_2:   19x19 -> 10x10\n",
"    (1, 512, 128, 0, 1), (3, 128, 256, 1, 2),   # conv9_1, conv9_2:   10x10 -> 5x5\n",
"    (1, 256, 128, 0, 1), (3, 128, 256, 0, 1),   # conv10_1, conv10_2: 5x5 -> 3x3\n",
"    (1, 256, 128, 0, 1), (3, 128, 256, 0, 1),   # conv11_1, conv11_2: 3x3 -> 1x1\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "81853b45",
"metadata": {},
"outputs": [],
"source": [
"class AuxiliaryConv(nn.Module):\n",
"    \"\"\"Extra conv stages after conv7 producing conv8_2, conv9_2, conv10_2 and conv11_2.\n",
"\n",
"    auc_cfg is consumed two entries at a time; each pair becomes one nn.Sequential stage\n",
"    whose output is one of the returned feature maps.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(AuxiliaryConv, self).__init__()\n",
"        self.auc_cfg = auc_cfg\n",
"        # nn.ModuleList (not a plain list) so the stages are registered submodules:\n",
"        # their parameters show up in .parameters()/.state_dict() and follow .to(device).\n",
"        self.fm = nn.ModuleList()\n",
"        self.net()\n",
"        self.init_param()\n",
"\n",
"    def net(self):\n",
"        \"\"\"Group auc_cfg into consecutive pairs and append each pair to self.fm as an nn.Sequential.\"\"\"\n",
"        sequential = []\n",
"        for i, c in enumerate(self.auc_cfg):\n",
"            sequential.append(BaseConv(c[0], c[1], c[2], padding=c[3], stride=c[4], act=True))\n",
"            if i % 2 == 1:  # every 2 convs form one stage\n",
"                self.fm.append(nn.Sequential(*sequential))\n",
"                sequential = []\n",
"\n",
"    def forward(self, conv7):\n",
"        conv8_2 = self.fm[0](conv7)\n",
"        conv9_2 = self.fm[1](conv8_2)\n",
"        conv10_2 = self.fm[2](conv9_2)\n",
"        conv11_2 = self.fm[3](conv10_2)\n",
"        return conv8_2, conv9_2, conv10_2, conv11_2\n",
"\n",
"    def init_param(self):\n",
"        \"\"\"Xavier-normal weights and zero biases for every nn.Conv2d in this module.\"\"\"\n",
"        # modules() recurses into ModuleList -> Sequential -> BaseConv -> Conv2d;\n",
"        # children() only yields the top-level ModuleList and would never match a Conv2d.\n",
"        for module in self.modules():\n",
"            if isinstance(module, nn.Conv2d):\n",
"                nn.init.xavier_normal_(module.weight)  # trailing '_' = in-place initialiser\n",
"                nn.init.constant_(module.bias, 0.)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2c185733",
"metadata": {},
"outputs": [],
"source": [
"auc = AuxiliaryConv()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fc1cdf31",
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn((1, 1024, 19, 19))  # dummy conv7 map, shaped like the VGG16 backbone output above"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "287755f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 512, 10, 10]),\n",
" torch.Size([1, 256, 5, 5]),\n",
" torch.Size([1, 256, 3, 3]),\n",
" torch.Size([1, 256, 1, 1]))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# One forward pass; shapes of conv8_2, conv9_2, conv10_2 and conv11_2.\n",
"tuple(feature.shape for feature in auc(x))"
]
},
{
"cell_type": "markdown",
"id": "fac84e92",
"metadata": {},
"source": [
"<img src=\"boxes.png\" width=70% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "82d4527c",
"metadata": {},
"outputs": [],
"source": [
"# Per-prior output sizes of the two prediction heads\n",
"num_classes, coords = 20, 4  # 20 class scores, 4 box coordinates"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d8880b00",
"metadata": {},
"outputs": [],
"source": [
"# Prediction heads, all 3x3 convs: (kernel_size, in_channels, out_channels).\n",
"# First 6 entries regress box coordinates, last 6 score classes, one per source map\n",
"# (conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2); the multiplier is the\n",
"# number of priors per cell on that map (4, 6, 6, 6, 4, 4).\n",
"pred_cfg = [(3, in_ch, n_boxes * per_box)\n",
"            for per_box in (coords, num_classes)\n",
"            for in_ch, n_boxes in zip((512, 1024, 512, 256, 256, 256), (4, 6, 6, 6, 4, 4))]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "17aa68e1",
"metadata": {},
"outputs": [],
"source": [
"class Prediction(nn.Module):\n",
"    \"\"\"SSD prediction heads.\n",
"\n",
"    For each of the six source feature maps one 3x3 conv regresses 4 box coordinates and\n",
"    another scores `num_classes` classes for every prior of every cell. Results are flattened\n",
"    and concatenated in the order conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2\n",
"    (8732 priors for a 300x300 input).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(Prediction, self).__init__()\n",
"        self.num_classes = num_classes\n",
"        self.pred_cfg = pred_cfg\n",
"        # nn.ModuleList so the heads are registered submodules (trained, saved, moved by .to()).\n",
"        # Entries 0-5: localisation heads; entries 6-11: classification heads.\n",
"        self.fm = nn.ModuleList()\n",
"        self.net()\n",
"        self.init_param()\n",
"\n",
"    def net(self):\n",
"        \"\"\"One BaseConv per pred_cfg entry, without ReLU (raw regression outputs / logits).\"\"\"\n",
"        for c in self.pred_cfg:\n",
"            self.fm.append(BaseConv(c[0], c[1], c[2], act=False))\n",
"\n",
"    def init_param(self):\n",
"        \"\"\"Xavier-normal weights and zero biases for every nn.Conv2d head.\"\"\"\n",
"        # modules() recurses into BaseConv; children() would never yield an nn.Conv2d here.\n",
"        for module in self.modules():\n",
"            if isinstance(module, nn.Conv2d):\n",
"                nn.init.xavier_normal_(module.weight)  # trailing '_' = in-place initialiser\n",
"                nn.init.constant_(module.bias, 0.)\n",
"\n",
"    @staticmethod\n",
"    def _flatten(x, last_dim):\n",
"        \"\"\"(N, C, H, W) -> (N, H*W*C/last_dim, last_dim), keeping each cell's priors contiguous.\"\"\"\n",
"        # channels-last before view() so the channels of one cell stay together;\n",
"        # contiguous() is needed because view() requires a contiguous memory layout.\n",
"        return x.permute(0, 2, 3, 1).contiguous().view(x.size(0), -1, last_dim)\n",
"\n",
"    def forward(self, conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2):\n",
"        \"\"\"Return (locations: (N, n_priors, 4), classes: (N, n_priors, num_classes)).\"\"\"\n",
"        feature_maps = (conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2)\n",
"        n_maps = len(feature_maps)\n",
"        locations = torch.cat(\n",
"            [self._flatten(self.fm[i](f), 4) for i, f in enumerate(feature_maps)], dim=1)\n",
"        classes = torch.cat(\n",
"            [self._flatten(self.fm[n_maps + i](f), self.num_classes) for i, f in enumerate(feature_maps)],\n",
"            dim=1)\n",
"        return locations, classes"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "dbe28eff",
"metadata": {},
"outputs": [],
"source": [
"p = Prediction()  # prediction heads under test"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5ecde33e",
"metadata": {},
"outputs": [],
"source": [
"# Dummy batch of the six source maps, shaped like the VGG16 / AuxiliaryConv outputs above.\n",
"conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2 = (\n",
"    torch.randn((1, channels, size, size))\n",
"    for channels, size in [(512, 38), (1024, 19), (512, 10), (256, 5), (256, 3), (256, 1)]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "579114f2",
"metadata": {},
"outputs": [],
"source": [
"out = p(conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "012ac090",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 8732, 4]), torch.Size([1, 8732, 20]))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tuple(t.shape for t in out)  # (locations, classes)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "86d16dc4",
"metadata": {},
"outputs": [],
"source": [
"# nn.Parameter registers a tensor as trainable (the SSD cell uses one for `rescale`).\n",
"# torch.full gives defined values; torch.FloatTensor(1, 100) would hold uninitialised memory.\n",
"x = nn.Parameter(torch.full((1, 100), 20.))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "5f306349",
"metadata": {},
"outputs": [],
"source": [
"# Per source feature map, in prediction order:\n",
"#   fm_dims       -> spatial size (H = W) of the map for a 300x300 input\n",
"#   boxes_scale   -> prior-box scale s_k of that map (SSD paper: fraction of the image size)\n",
"#   aspect_ratios -> aspect ratios of the priors placed at every cell\n",
"fm_dims = dict(conv4_3=38, conv7=19, conv8_2=10, conv9_2=5, conv10_2=3, conv11_2=1)\n",
"boxes_scale = dict(conv4_3=0.1, conv7=0.2, conv8_2=0.375,\n",
"                   conv9_2=0.55, conv10_2=0.725, conv11_2=0.9)\n",
"aspect_ratios = dict(conv4_3=[1., 2., 0.5],\n",
"                     conv7=[1., 2., 3., 0.5, 0.333],\n",
"                     conv8_2=[1., 2., 3., 0.5, 0.333],\n",
"                     conv9_2=[1., 2., 3., 0.5, 0.333],\n",
"                     conv10_2=[1., 2., 0.5],\n",
"                     conv11_2=[1., 2., 0.5])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba54e968",
"metadata": {},
"outputs": [],
"source": [
"class SSD(nn.Module):\n",
"    \"\"\"SSD300: VGG16 backbone + auxiliary convs + prediction heads.\n",
"\n",
"    forward(image) returns (locations, classes) for every prior box; the priors are\n",
"    built once by get_prior_boxes() and kept in self.priors_cxcy.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super(SSD, self).__init__()\n",
"        self.fm_dims = fm_dims\n",
"        self.boxes_scale = boxes_scale\n",
"        self.aspect_ratios = aspect_ratios\n",
"        self.base = VGG16()\n",
"        self.auc = AuxiliaryConv()\n",
"        self.prediction = Prediction()\n",
"        # Lower-level conv4_3 features have a larger scale than the deeper maps, so they are\n",
"        # L2-normalised per location and multiplied by a learnable per-channel factor.\n",
"        self.rescale = nn.Parameter(torch.FloatTensor(1, 512, 1, 1))\n",
"        nn.init.constant_(self.rescale, 20)  # every channel starts at 20; updated by backprop\n",
"        self.priors_cxcy = self.get_prior_boxes()\n",
"\n",
"    def forward(self, image):\n",
"        conv4_3, conv7 = self.base(image)  # (N, 512, 38, 38), (N, 1024, 19, 19)\n",
"        # L2 norm over channels; clamp avoids 0/0 = NaN where ReLU zeroed every channel\n",
"        norm = conv4_3.pow(2).sum(dim=1, keepdim=True).sqrt().clamp(min=1e-10)  # (N, 1, 38, 38)\n",
"        conv4_3 = conv4_3 / norm * self.rescale  # (N, 512, 38, 38)\n",
"\n",
"        # (N, 512, 10, 10), (N, 256, 5, 5), (N, 256, 3, 3), (N, 256, 1, 1)\n",
"        conv8_2, conv9_2, conv10_2, conv11_2 = self.auc(conv7)\n",
"        locations, classes = self.prediction(conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2)\n",
"        return locations, classes\n",
"\n",
"    def get_prior_boxes(self):\n",
"        \"\"\"Build the prior (default) boxes as an (n_priors, 4) tensor of (cx, cy, w, h) in [0, 1].\n",
"\n",
"        Priors are generated map by map in fm_dims order, row-major over cells and in\n",
"        aspect_ratios order within a cell, matching the flattening order of Prediction.\n",
"        For ratio 1 an extra square prior of scale sqrt(s_k * s_{k+1}) is added (1.0 on the\n",
"        last map), as in the SSD paper. With the config above this yields 8732 priors.\n",
"        \"\"\"\n",
"        from math import sqrt\n",
"\n",
"        fmaps = list(self.fm_dims.keys())\n",
"        prior_boxes = []\n",
"        for k, fmap in enumerate(fmaps):\n",
"            dim = self.fm_dims[fmap]\n",
"            scale = self.boxes_scale[fmap]\n",
"            for i in range(dim):\n",
"                for j in range(dim):\n",
"                    cx = (j + 0.5) / dim\n",
"                    cy = (i + 0.5) / dim\n",
"                    for ratio in self.aspect_ratios[fmap]:\n",
"                        prior_boxes.append([cx, cy, scale * sqrt(ratio), scale / sqrt(ratio)])\n",
"                        if ratio == 1.:\n",
"                            if k + 1 < len(fmaps):\n",
"                                extra_scale = sqrt(scale * self.boxes_scale[fmaps[k + 1]])\n",
"                            else:\n",
"                                extra_scale = 1.\n",
"                            prior_boxes.append([cx, cy, extra_scale, extra_scale])\n",
"        return torch.tensor(prior_boxes, dtype=torch.float32).clamp_(0, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b5cc276",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}