You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

683 lines
26 KiB
Plaintext

6 months ago
{
"cells": [
{
"cell_type": "markdown",
"id": "06b0d2c9",
"metadata": {},
"source": [
"# <center> SSD算法复现"
]
},
{
"cell_type": "markdown",
"id": "30e46dc6",
"metadata": {},
"source": [
"在上一课程中我们复现了VGG16论文中的算法。在SSD算法中还不能直接拿VGG16作为特征提取器，需要先对模型做一些修改。"
]
},
{
"cell_type": "markdown",
"id": "8d1c9081",
"metadata": {},
"source": [
"## 1.算法结构"
]
},
{
"cell_type": "markdown",
"id": "03d9cb10",
"metadata": {},
"source": [
"<img src='ssd.png' width=60% align=center>"
]
},
{
"cell_type": "markdown",
"id": "9fe1a1cb",
"metadata": {},
"source": [
"<img src='vgg_detail.png' width=70% align=center>"
]
},
{
"cell_type": "markdown",
"id": "8df5108a",
"metadata": {},
"source": [
"1.输入尺寸不再是224，在SSD算法中可以选择300或者512，分别称为SSD300、SSD512。  \n",
"2.第3个最大池化层使用ceil模式而不是floor模式，其他参数不变，具体用法请参考nn.MaxPool2d。conv3_3卷积后的特征图尺寸是$75\\times75$，使用ceil可以保证池化后的特征图是$38\\times38$（偶数），否则是$37\\times37$（奇数），这样做是为了方便后续处理。  \n",
"3.把第5个最大池化层的核大小改成3、步长改成1。  \n",
"4.去掉第8层全连接层（fc8），把全连接层fc6、fc7改成卷积层conv6、conv7（其中conv6是dilation为6的空洞卷积）。  \n",
"5.我们只需要关注conv4_3和conv7就可以了。"
]
},
{
"cell_type": "markdown",
"id": "5fb20620",
"metadata": {},
"source": [
"<img src=\"mod_vgg16.png\" width=80% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9bef844f",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torchsummary import summary\n",
"from torch.utils.data import DataLoader\n",
"from torchvision.datasets import FashionMNIST\n",
"from torchvision import models"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1bda593e",
"metadata": {},
"outputs": [],
"source": [
"# use the GPU when torch can see one, otherwise run on the CPU\n",
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ebfca277",
"metadata": {},
"outputs": [],
"source": [
"# Modified VGG16 layout used by the SSD backbone.\n",
"# Tuples are convolutions: (kernel_size, in_channels, out_channels[, padding, dilation]).\n",
"# Strings are max-pool layers:\n",
"#   'M'   -> 2x2, stride 2, floor mode\n",
"#   'CM'  -> 2x2, stride 2, ceil mode (keeps 75x75 -> 38x38 instead of 37x37)\n",
"#   '3M1' -> 3x3, stride 1 (pool5)\n",
"vgg_cfg = [\n",
"    (3, 3, 64), (3, 64, 64), 'M',                        # conv1_x + pool1\n",
"    (3, 64, 128), (3, 128, 128), 'M',                    # conv2_x + pool2\n",
"    (3, 128, 256), (3, 256, 256), (3, 256, 256), 'CM',   # conv3_x + pool3\n",
"    (3, 256, 512), (3, 512, 512), (3, 512, 512), 'M',    # conv4_x + pool4\n",
"    (3, 512, 512), (3, 512, 512), (3, 512, 512), '3M1',  # conv5_x + pool5\n",
"    (3, 512, 1024, 6, 6),                                # conv6: dilated 3x3, replaces fc6\n",
"    (1, 1024, 1024),                                     # conv7: 1x1, replaces fc7\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e103335f",
"metadata": {},
"outputs": [],
"source": [
"class BaseConv(nn.Module):\n",
"    # One nn.Conv2d, optionally followed by a ReLU.\n",
"    # Positional order: (kernel_size, in_channels, out_channels, padding, dilation, stride).\n",
"    def __init__(self, kernel_size, in_channels, out_channels, padding=1, dilation=1, stride=1, act=True):\n",
"        super(BaseConv, self).__init__()\n",
"        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,\n",
"                              stride=stride, padding=padding, dilation=dilation)\n",
"        self.relu = nn.ReLU()\n",
"        self.act = act  # False -> return the raw convolution output\n",
"\n",
"    def forward(self, x):\n",
"        out = self.conv(x)\n",
"        if not self.act:\n",
"            return out\n",
"        return self.relu(out)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "125c171c",
"metadata": {},
"outputs": [],
"source": [
"class VGG16(nn.Module):\n",
"    # Modified VGG16 backbone for SSD; forward() returns the conv4_3 and conv7 feature maps.\n",
"    def __init__(self):\n",
"        super(VGG16, self).__init__()\n",
"        self.vgg_cfg = vgg_cfg\n",
"        self.seq1 = self.net1()  # conv1_1 ... conv4_3 (vgg_cfg[:13])\n",
"        self.seq2 = self.net2()  # pool4 ... conv7 (vgg_cfg[13:])\n",
"\n",
"    @staticmethod\n",
"    def _conv(c):\n",
"        # Build a BaseConv from a cfg tuple (kernel_size, in_channels, out_channels[, padding, dilation]).\n",
"        # Without an explicit padding use kernel_size // 2 to keep the spatial size: 1 for the 3x3\n",
"        # convs, 0 for the 1x1 conv7 (BaseConv's default padding=1 would add a border to conv7).\n",
"        if len(c) == 3:\n",
"            return BaseConv(c[0], c[1], c[2], padding=c[0] // 2, act=True)\n",
"        return BaseConv(c[0], c[1], c[2], c[3], c[4], act=True)\n",
"\n",
"    def net1(self):\n",
"        layers = []\n",
"        for c in self.vgg_cfg[:13]:\n",
"            if isinstance(c, tuple):\n",
"                layers.append(self._conv(c))\n",
"            elif c == 'CM':\n",
"                # ceil mode turns conv3_3's 75x75 map into 38x38 instead of 37x37\n",
"                layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))\n",
"            else:  # 'M'\n",
"                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def net2(self):\n",
"        layers = []\n",
"        for c in self.vgg_cfg[13:]:\n",
"            if isinstance(c, tuple):\n",
"                layers.append(self._conv(c))\n",
"            elif c == 'M':\n",
"                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))\n",
"            else:  # '3M1': pool5, 3x3 / stride 1 / padding 1 so the 19x19 map keeps its size\n",
"                layers.append(nn.MaxPool2d(kernel_size=3, stride=1, padding=1))\n",
"        return nn.Sequential(*layers)\n",
"\n",
"    def forward(self, x):\n",
"        conv4_3 = self.seq1(x)  # (N, 512, 38, 38) for a 300x300 input\n",
"        conv7 = self.seq2(conv4_3)  # (N, 1024, 19, 19) for a 300x300 input\n",
"        return conv4_3, conv7\n",
"\n",
"    def load_pretrained_params(self):\n",
"        # Copy torchvision's pretrained VGG16 weights into this network. The conv layers map\n",
"        # one-to-one; fc6/fc7 are reshaped into conv kernels and subsampled to conv6/conv7 size.\n",
"        current_dict = self.state_dict()\n",
"        current_names = list(current_dict.keys())\n",
"        pretrained_dict = models.vgg16(weights=models.VGG16_Weights.DEFAULT).state_dict()\n",
"        # Only the 13 conv layers (the 26 'features.*' tensors) line up, in order, with\n",
"        # seq1 + conv5_x; the classifier tensors are converted separately below.\n",
"        feature_names = [n for n in pretrained_dict if n.startswith('features.')]\n",
"        for cur_name, pre_name in zip(current_names, feature_names):\n",
"            current_dict[cur_name] = pretrained_dict[pre_name]\n",
"        # fc6: (4096, 25088) -> (4096, 512, 7, 7); fc7: (4096, 4096) -> (4096, 4096, 1, 1)\n",
"        fc6_weights = pretrained_dict['classifier.0.weight'].view(4096, 512, 7, 7)\n",
"        fc6_bias = pretrained_dict['classifier.0.bias']\n",
"        fc7_weights = pretrained_dict['classifier.3.weight'].view(4096, 4096, 1, 1)\n",
"        fc7_bias = pretrained_dict['classifier.3.bias']\n",
"        # Keep every 4th output/input channel and every 3rd kernel tap:\n",
"        # conv6 -> (1024, 512, 3, 3), conv7 -> (1024, 1024, 1, 1)\n",
"        current_dict['seq2.5.conv.weight'] = self.decimate(fc6_weights, m=[4, None, 3, 3])\n",
"        current_dict['seq2.5.conv.bias'] = self.decimate(fc6_bias, m=[4])\n",
"        current_dict['seq2.6.conv.weight'] = self.decimate(fc7_weights, m=[4, 4, None, None])\n",
"        current_dict['seq2.6.conv.bias'] = self.decimate(fc7_bias, m=[4])\n",
"        self.load_state_dict(current_dict)\n",
"        print(\"pretrained params load finished!\")\n",
"\n",
"    def decimate(self, tensor, m):\n",
"        \"\"\"\n",
"        Downsample `tensor` by keeping every m[d]-th index along each dimension d.\n",
"\n",
"        Used to turn the pretrained fc6/fc7 weights into the smaller conv6/conv7 kernels.\n",
"\n",
"        :param tensor: tensor to be decimated\n",
"        :param m: one decimation factor per dimension of `tensor`; None leaves that dimension untouched\n",
"        :return: decimated tensor\n",
"        \"\"\"\n",
"        assert tensor.dim() == len(m)\n",
"        for d in range(tensor.dim()):\n",
"            if m[d] is not None:\n",
"                index = torch.arange(start=0, end=tensor.size(d), step=m[d], device=tensor.device).long()\n",
"                tensor = tensor.index_select(dim=d, index=index)\n",
"        return tensor"
]
},
{
"cell_type": "markdown",
"id": "ea144c5a",
"metadata": {},
"source": [
"VGG16的参数名如下，一共32个；每一层包含权重和偏置两个参数，所以对应16层。序号表示该层在features或classifier中的下标，比如0之后是2，因为第0层卷积后面还有一个没有参数的激活函数ReLU（下标1）；池化层同样没有参数但也占用下标，其他层的下标都是这样计算的。"
]
},
{
"cell_type": "markdown",
"id": "112b7c0c",
"metadata": {},
"source": [
"官网的模型名字:['features.0.weight', 'features.0.bias', 'features.2.weight', 'features.2.bias', 'features.5.weight', 'features.5.bias', 'features.7.weight', 'features.7.bias', 'features.10.weight', 'features.10.bias', 'features.12.weight', 'features.12.bias', 'features.14.weight', 'features.14.bias', 'features.17.weight', 'features.17.bias', 'features.19.weight', 'features.19.bias', 'features.21.weight', 'features.21.bias', 'features.24.weight', 'features.24.bias', 'features.26.weight', 'features.26.bias', 'features.28.weight', 'features.28.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']"
]
},
{
"cell_type": "markdown",
"id": "11c2d3f8",
"metadata": {},
"source": [
"conv6, conv7对应的层为:'seq2.5.conv.weight', 'seq2.5.conv.bias', 'seq2.6.conv.weight', 'seq2.6.conv.bias'"
]
},
{
"cell_type": "markdown",
"id": "07564988",
"metadata": {},
"source": [
"自己的模型名字:['seq1.0.conv.weight', 'seq1.0.conv.bias', 'seq1.1.conv.weight', 'seq1.1.conv.bias', 'seq1.3.conv.weight', 'seq1.3.conv.bias', 'seq1.4.conv.weight', 'seq1.4.conv.bias', 'seq1.6.conv.weight', 'seq1.6.conv.bias', 'seq1.7.conv.weight', 'seq1.7.conv.bias', 'seq1.8.conv.weight', 'seq1.8.conv.bias', 'seq1.10.conv.weight', 'seq1.10.conv.bias', 'seq1.11.conv.weight', 'seq1.11.conv.bias', 'seq1.12.conv.weight', 'seq1.12.conv.bias', 'seq2.1.conv.weight', 'seq2.1.conv.bias', 'seq2.2.conv.weight', 'seq2.2.conv.bias', 'seq2.3.conv.weight', 'seq2.3.conv.bias', 'seq2.5.conv.weight', 'seq2.5.conv.bias', 'seq2.6.conv.weight', 'seq2.6.conv.bias']"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "883d01b4",
"metadata": {},
"outputs": [],
"source": [
"# backbone only, randomly initialised (load_pretrained_params() is not called here)\n",
"cls = VGG16()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f225db7b",
"metadata": {},
"outputs": [],
"source": [
"# dummy SSD300 input: one 3-channel 300x300 image\n",
"x = torch.randn(1, 3, 300, 300)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "12fafd7a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 512, 38, 38]), torch.Size([1, 1024, 19, 19]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# one forward pass (the old form ran the backbone twice); shapes of conv4_3 and conv7\n",
"tuple(feature.shape for feature in cls(x))"
]
},
{
"cell_type": "markdown",
"id": "b56f3146",
"metadata": {},
"source": [
"除了以VGG16作为基础网络以外，还会在其最后一层之后增加辅助卷积层，目的是获取conv8_2、conv9_2、conv10_2、conv11_2的特征图。"
]
},
{
"cell_type": "markdown",
"id": "0e95f3d0",
"metadata": {},
"source": [
"<img src=\"auxiliary_conv.png\" width=80% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9cf3dd8e",
"metadata": {},
"outputs": [],
"source": [
"# Auxiliary convs as (kernel_size, in_channels, out_channels, padding, stride).\n",
"# Consecutive pairs form one stage each; the second conv of a pair is the kept *_2 map.\n",
"auc_cfg = [\n",
"    (1, 1024, 256, 0, 1), (3, 256, 512, 1, 2),  # conv8_1, conv8_2:   19x19 -> 10x10\n",
"    (1, 512, 128, 0, 1), (3, 128, 256, 1, 2),   # conv9_1, conv9_2:   10x10 -> 5x5\n",
"    (1, 256, 128, 0, 1), (3, 128, 256, 0, 1),   # conv10_1, conv10_2: 5x5 -> 3x3\n",
"    (1, 256, 128, 0, 1), (3, 128, 256, 0, 1),   # conv11_1, conv11_2: 3x3 -> 1x1\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "81853b45",
"metadata": {},
"outputs": [],
"source": [
"class AuxiliaryConv(nn.Module):\n",
"    # Extra conv stages stacked on conv7; forward() returns conv8_2, conv9_2, conv10_2, conv11_2.\n",
"    def __init__(self):\n",
"        super(AuxiliaryConv, self).__init__()\n",
"        self.auc_cfg = auc_cfg\n",
"        # nn.ModuleList (not a plain list) so the stages are registered submodules:\n",
"        # otherwise their weights are missing from parameters()/state_dict() and are\n",
"        # skipped by .to(device) and by the optimizer.\n",
"        self.fm = nn.ModuleList()\n",
"        self.net()\n",
"        self.init_param()\n",
"\n",
"    def net(self):\n",
"        # Group the cfg two convs at a time (1x1 reduce, then 3x3) into one Sequential stage;\n",
"        # an unpaired trailing entry is ignored.\n",
"        for i in range(0, len(self.auc_cfg) - 1, 2):\n",
"            stage = [BaseConv(c[0], c[1], c[2], padding=c[3], stride=c[4], act=True)\n",
"                     for c in self.auc_cfg[i:i + 2]]\n",
"            self.fm.append(nn.Sequential(*stage))\n",
"\n",
"    def forward(self, conv7):\n",
"        conv8_2 = self.fm[0](conv7)\n",
"        conv9_2 = self.fm[1](conv8_2)\n",
"        conv10_2 = self.fm[2](conv9_2)\n",
"        conv11_2 = self.fm[3](conv10_2)\n",
"        return conv8_2, conv9_2, conv10_2, conv11_2\n",
"\n",
"    def init_param(self):\n",
"        # Xavier-normal weights and zero biases for every conv. modules() recurses into the\n",
"        # Sequential/BaseConv wrappers; children() only yields the top level and never\n",
"        # reaches an nn.Conv2d.\n",
"        for module in self.modules():\n",
"            if isinstance(module, nn.Conv2d):\n",
"                nn.init.xavier_normal_(module.weight)\n",
"                nn.init.constant_(module.bias, 0.)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "2c185733",
"metadata": {},
"outputs": [],
"source": [
"# the auxiliary stages on their own, to check the map sizes they produce from conv7\n",
"auc = AuxiliaryConv()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fc1cdf31",
"metadata": {},
"outputs": [],
"source": [
"# stand-in for the backbone's conv7 output: (N, 1024, 19, 19)\n",
"x = torch.randn(1, 1024, 19, 19)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "287755f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 512, 10, 10]),\n",
" torch.Size([1, 256, 5, 5]),\n",
" torch.Size([1, 256, 3, 3]),\n",
" torch.Size([1, 256, 1, 1]))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# one forward pass (the old form ran the auxiliary convs four times)\n",
"tuple(fmap.shape for fmap in auc(x))"
]
},
{
"cell_type": "markdown",
"id": "fac84e92",
"metadata": {},
"source": [
"<img src=\"boxes.png\" width=70% align=center>"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "82d4527c",
"metadata": {},
"outputs": [],
"source": [
"# SSD scores every prior against each object class *and* a background class,\n",
"# so the classifier needs 20 object classes + 1 background output.\n",
"num_classes = 20 + 1\n",
"coords = 4  # 4 box coordinates predicted per prior"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d8880b00",
"metadata": {},
"outputs": [],
"source": [
"# (kernel_size, in_channels, out_channels) of each 3x3 prediction conv:\n",
"# first the six localisation heads, then the six classification heads, both ordered\n",
"# conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2 with 4, 6, 6, 6, 4, 4 boxes per location.\n",
"pred_cfg = [(3, ch, nb * per_box)\n",
"            for per_box in (coords, num_classes)\n",
"            for ch, nb in zip((512, 1024, 512, 256, 256, 256), (4, 6, 6, 6, 4, 4))]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "17aa68e1",
"metadata": {},
"outputs": [],
"source": [
"class Prediction(nn.Module):\n",
"    # SSD heads: per feature map, one 3x3 conv for box offsets and one for class scores.\n",
"    def __init__(self):\n",
"        super(Prediction, self).__init__()\n",
"        self.num_classes = num_classes\n",
"        self.pred_cfg = pred_cfg\n",
"        # nn.ModuleList so the heads are registered (parameters(), state_dict(), .to(device)).\n",
"        self.fm = nn.ModuleList()\n",
"        self.net()\n",
"        self.init_param()\n",
"\n",
"    def net(self):\n",
"        # fm[0:6] are the localisation heads and fm[6:12] the classification heads, both\n",
"        # ordered conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2 (see pred_cfg).\n",
"        for c in self.pred_cfg:\n",
"            self.fm.append(BaseConv(c[0], c[1], c[2], act=False))\n",
"\n",
"    def init_param(self):\n",
"        # Xavier-normal weights, zero biases; modules() is needed to reach the Conv2d inside each BaseConv.\n",
"        for module in self.modules():\n",
"            if isinstance(module, nn.Conv2d):\n",
"                nn.init.xavier_normal_(module.weight)\n",
"                nn.init.constant_(module.bias, 0.)\n",
"\n",
"    @staticmethod\n",
"    def _flatten(pred, batch_size, last_dim):\n",
"        # (N, boxes * last_dim, H, W) -> (N, H * W * boxes, last_dim); priors ordered row, column, box\n",
"        return pred.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, last_dim)\n",
"\n",
"    def forward(self, conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2):\n",
"        feature_maps = (conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2)\n",
"        batch_size = conv4_3.size(0)\n",
"        n_maps = len(feature_maps)\n",
"        # (N, n_priors, 4); n_priors is 8732 for the SSD300 map sizes\n",
"        locations = torch.cat([self._flatten(self.fm[k](f), batch_size, 4)\n",
"                               for k, f in enumerate(feature_maps)], dim=1)\n",
"        # (N, n_priors, num_classes)\n",
"        classes = torch.cat([self._flatten(self.fm[n_maps + k](f), batch_size, self.num_classes)\n",
"                             for k, f in enumerate(feature_maps)], dim=1)\n",
"        return locations, classes"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "dbe28eff",
"metadata": {},
"outputs": [],
"source": [
"# prediction heads on their own, fed with dummy feature maps below\n",
"p = Prediction()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5ecde33e",
"metadata": {},
"outputs": [],
"source": [
"# random stand-ins for the six SSD300 feature maps as (channels, spatial size), in order\n",
"conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2 = (\n",
"    torch.randn(1, ch, size, size)\n",
"    for ch, size in ((512, 38), (1024, 19), (512, 10), (256, 5), (256, 3), (256, 1)))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "579114f2",
"metadata": {},
"outputs": [],
"source": [
"# run the six dummy feature maps through the prediction heads\n",
"out = p(conv4_3, conv7,\n",
"        conv8_2, conv9_2,\n",
"        conv10_2, conv11_2)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "012ac090",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 8732, 4]), torch.Size([1, 8732, 20]))"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# (locations, classes) shapes\n",
"tuple(t.shape for t in out)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "86d16dc4",
"metadata": {},
"outputs": [],
"source": [
"# scratch check of nn.Parameter (the wrapper SSD uses for its conv4_3 rescale factor);\n",
"# torch.FloatTensor(1, 100) holds uninitialised memory, so start from zeros for a reproducible value\n",
"x = nn.Parameter(torch.zeros(1, 100))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "5f306349",
"metadata": {},
"outputs": [],
"source": [
"# Per feature map: spatial size, prior-box scale, and aspect ratios (same key order everywhere).\n",
"fm_dims = dict(conv4_3=38, conv7=19, conv8_2=10, conv9_2=5, conv10_2=3, conv11_2=1)\n",
"boxes_scale = dict(conv4_3=0.1, conv7=0.2, conv8_2=0.375,\n",
"                   conv9_2=0.55, conv10_2=0.725, conv11_2=0.9)\n",
"aspect_ratios = dict(conv4_3=[1., 2., 0.5],\n",
"                     conv7=[1., 2., 3., 0.5, 0.333],\n",
"                     conv8_2=[1., 2., 3., 0.5, 0.333],\n",
"                     conv9_2=[1., 2., 3., 0.5, 0.333],\n",
"                     conv10_2=[1., 2., 0.5],\n",
"                     conv11_2=[1., 2., 0.5])"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "ba54e968",
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid character in identifier (3965864224.py, line 12)",
"output_type": "error",
"traceback": [
"\u001b[1;36m File \u001b[1;32m\"C:\\Users\\Stark-lin\\AppData\\Local\\Temp\\ipykernel_6056\\3965864224.py\"\u001b[1;36m, line \u001b[1;32m12\u001b[0m\n\u001b[1;33m def forward(self image):\u001b[0m\n\u001b[1;37m ^\u001b[0m\n\u001b[1;31mSyntaxError\u001b[0m\u001b[1;31m:\u001b[0m invalid character in identifier\n"
]
}
],
"source": [
"class SSD(nn.Module):\n",
"    # SSD300: modified VGG16 backbone + auxiliary convs + prediction heads.\n",
"    def __init__(self):\n",
"        super(SSD, self).__init__()\n",
"        self.fm_dims = fm_dims\n",
"        self.boxes_scale = boxes_scale\n",
"        self.aspect_ratios = aspect_ratios\n",
"        self.base = VGG16()\n",
"        self.auc = AuxiliaryConv()\n",
"        self.prediction = Prediction()\n",
"        # conv4_3 is L2-normalised over channels and multiplied by this learnable per-channel\n",
"        # factor (initialised to 20) so its scale is comparable to the deeper feature maps.\n",
"        self.rescale = nn.Parameter(torch.FloatTensor(1, 512, 1, 1))\n",
"        nn.init.constant_(self.rescale, 20)\n",
"\n",
"    def forward(self, image):\n",
"        conv4_3, conv7 = self.base(image)  # (N, 512, 38, 38), (N, 1024, 19, 19)\n",
"        # clamp avoids 0/0 = NaN where every channel is zero after the ReLU\n",
"        norm = conv4_3.pow(2).sum(dim=1, keepdim=True).sqrt().clamp(min=1e-10)  # (N, 1, 38, 38)\n",
"        conv4_3 = conv4_3 / norm * self.rescale\n",
"        # (N, 512, 10, 10), (N, 256, 5, 5), (N, 256, 3, 3), (N, 256, 1, 1)\n",
"        conv8_2, conv9_2, conv10_2, conv11_2 = self.auc(conv7)\n",
"        locations, classes = self.prediction(conv4_3, conv7, conv8_2, conv9_2, conv10_2, conv11_2)\n",
"        return locations, classes\n",
"\n",
"    def get_prior_boxes(self):\n",
"        \"\"\"\n",
"        Build the default (prior) boxes as (cx, cy, w, h), relative to the image size.\n",
"\n",
"        Boxes are generated map by map in fm_dims order, then row by row and column by\n",
"        column, matching the (N, H, W, C) flattening in Prediction. Every aspect ratio gives\n",
"        one box; ratio 1 also gets an extra square box whose size is the geometric mean of\n",
"        this map's scale and the next map's scale (1 for the last map).\n",
"\n",
"        :return: float tensor of shape (n_priors, 4) clamped to [0, 1], on the model's device\n",
"        \"\"\"\n",
"        names = list(self.fm_dims.keys())\n",
"        boxes = []\n",
"        for k, name in enumerate(names):\n",
"            dim = self.fm_dims[name]\n",
"            scale = self.boxes_scale[name]\n",
"            if k + 1 < len(names):\n",
"                extra = (scale * self.boxes_scale[names[k + 1]]) ** 0.5\n",
"            else:\n",
"                extra = 1.\n",
"            for i in range(dim):\n",
"                for j in range(dim):\n",
"                    cx = (j + 0.5) / dim\n",
"                    cy = (i + 0.5) / dim\n",
"                    for ratio in self.aspect_ratios[name]:\n",
"                        boxes.append([cx, cy, scale * ratio ** 0.5, scale / ratio ** 0.5])\n",
"                        if ratio == 1.:\n",
"                            boxes.append([cx, cy, extra, extra])\n",
"        priors = torch.tensor(boxes, dtype=torch.float32, device=self.rescale.device)\n",
"        return priors.clamp_(0, 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b5cc276",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}