行人重试别PCB

2019-07-14 06:25发布

行人重试别PCB 未使用默认标题

行人重识别PCB import torch import torch.nn as nn import torch.nn.init as init import torch.nn.functional as F from .resnet import resnet50 class PCBModel(nn.Module): def __init__( self, last_conv_stride=1, last_conv_dilation=1, num_stripes=6, local_conv_out_channels=256, num_classes=0 ): super(PCBModel, self).__init__() #通过去掉全连接层和分类层的resnet后得到的 base feature map self.base = resnet50( pretrained=True, last_conv_stride=last_conv_stride, last_conv_dilation=last_conv_dilation) #将feature map 分成多个 stripes self.num_stripes = num_stripes self.local_conv_list = nn.ModuleList() for _ in range(num_stripes): self.local_conv_list.append(nn.Sequential( # 进入的channel个数为 2048 输出的channel数为256(local_conv_out_channels=256) nn.Conv2d(2048, local_conv_out_channels, 1), nn.BatchNorm2d(local_conv_out_channels), nn.ReLU(inplace=True) )) if num_classes > 0: self.fc_list = nn.ModuleList() # 构造num_stripes 个分类器 分别连接到经过local_conv_list之后的256通道的feature map后 for _ in range(num_stripes): fc = nn.Linear(local_conv_out_channels, num_classes) init.normal(fc.weight, std=0.001) init.constant(fc.bias, 0) self.fc_list.append(fc) PCB网络framework def forward(self, x): """ Returns: local_feat_list: each member with shape [N, c] logits_list: each member with shape [N, num_classes] """ # shape [N, C, H, W] feat = self.base(x) assert feat.size(2) % self.num_stripes == 0 stripe_h = int(feat.size(2) / self.num_stripes) # 特征向量 local_feat_list = [] # 分类向量 由特征向量通过全连接得到 logits_list = [] for i in range(self.num_stripes): # shape [N, C, 1, 1] # 把每一个channel上的part区域做平均,得到图中的feature vector V local_feat = F.avg_pool2d( feat[:, :, i * stripe_h: (i + 1) * stripe_h, :], (stripe_h, feat.size(-1))) # shape [N, c, 1, 1] #得到图中的feature vector G local_feat = self.local_conv_list[i](local_feat) # shape [N, c] # resize feature map 到[N, C] local_feat = local_feat.view(local_feat.size(0), -1) local_feat_list.append(local_feat) # 如果有全连接层 那么就让 local_feat 通过fc层 if hasattr(self, 'fc_list'): logits_list.append(self.fc_list[i](local_feat)) #如果进行分类 那么就 return 分类完的特征 if hasattr(self, 'fc_list'): return local_feat_list, logits_list #如果不分类 仅进行metric learning 那么就return local_feat_list return local_feat_list