行人重试别PCB
2019-07-14 06:25发布
生成海报
行人重试别PCB 未使用默认标题
行人重识别PCB
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from .resnet import resnet50
class PCBModel(nn.Module):
def __init__(
self,
last_conv_stride=1,
last_conv_dilation=1,
num_stripes=6,
local_conv_out_channels=256,
num_classes=0
):
super(PCBModel, self).__init__()
#通过去掉全连接层和分类层的resnet后得到的 base feature map
self.base = resnet50(
pretrained=True,
last_conv_stride=last_conv_stride,
last_conv_dilation=last_conv_dilation)
#将feature map 分成多个 stripes
self.num_stripes = num_stripes
self.local_conv_list = nn.ModuleList()
for _ in range(num_stripes):
self.local_conv_list.append(nn.Sequential(
# 进入的channel个数为 2048 输出的channel数为256(local_conv_out_channels=256)
nn.Conv2d(2048, local_conv_out_channels, 1),
nn.BatchNorm2d(local_conv_out_channels),
nn.ReLU(inplace=True)
))
if num_classes > 0:
self.fc_list = nn.ModuleList()
# 构造num_stripes 个分类器 分别连接到经过local_conv_list之后的256通道的feature map后
for _ in range(num_stripes):
fc = nn.Linear(local_conv_out_channels, num_classes)
init.normal(fc.weight, std=0.001)
init.constant(fc.bias, 0)
self.fc_list.append(fc)
def forward(self, x):
"""
Returns:
local_feat_list: each member with shape [N, c]
logits_list: each member with shape [N, num_classes]
"""
# shape [N, C, H, W]
feat = self.base(x)
assert feat.size(2) % self.num_stripes == 0
stripe_h = int(feat.size(2) / self.num_stripes)
# 特征向量
local_feat_list = []
# 分类向量 由特征向量通过全连接得到
logits_list = []
for i in range(self.num_stripes):
# shape [N, C, 1, 1]
# 把每一个channel上的part区域做平均,得到图中的feature vector V
local_feat = F.avg_pool2d(
feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],
(stripe_h, feat.size(-1)))
# shape [N, c, 1, 1]
#得到图中的feature vector G
local_feat = self.local_conv_list[i](local_feat)
# shape [N, c]
# resize feature map 到[N, C]
local_feat = local_feat.view(local_feat.size(0), -1)
local_feat_list.append(local_feat)
# 如果有全连接层 那么就让 local_feat 通过fc层
if hasattr(self, 'fc_list'):
logits_list.append(self.fc_list[i](local_feat))
#如果进行分类 那么就 return 分类完的特征
if hasattr(self, 'fc_list'):
return local_feat_list, logits_list
#如果不分类 仅进行metric learning 那么就return local_feat_list
return local_feat_list
打开微信“扫一扫”,打开网页后点击屏幕右上角分享按钮