行人重识别(ReID) ——基于Person_reID_baseline_pytorch修改业务流程

2019-07-14 11:54发布

Person_reID_baseline_pytorch下载地址:https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
下载Market1501数据集:http://www.liangzheng.org/Project/project_reid.html
Market1501数据集结构:
├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
修改--test_dir路径,执行python prepare.py之后的数据集结构:
├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                  /* train
│           ├── 0002
│           ├── 0007
│           ...
│       ├── val/                    /* val
│       ├── train_all/              /* train+val
│       ├── query/                  /* query files
│       ├── gallery/                /* gallery files
训练模型并测试:将train.py、test.py中的--test_dir路径修改为/home/hylink/eclipse-workspace/reID/Market/pytorch,然后依次执行:
python train.py
python test.py
python demo.py --query_index 777
效果展示:
在这里插入图片描述 修改test.py(将原gallery和query生成底库,改为只生成gallery底库) # -*- coding: utf-8 -*- from __future__ import print_function, division import argparse import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable import numpy as np import torchvision from torchvision import datasets, models, transforms import time import os import scipy.io from model import ft_net, ft_net_dense, PCB, PCB_test ###################################################################### # Options # -------- parser = argparse.ArgumentParser(description='Training') parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last') parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data') parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path') parser.add_argument('--batchsize', default=32, type=int, help='batchsize') parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) parser.add_argument('--PCB', action='store_true', help='use PCB' ) parser.add_argument('--multi', action='store_true', help='use multiple query' ) opt = parser.parse_args() str_ids = opt.gpu_ids.split(',') #which_epoch = opt.which_epoch name = opt.name test_dir = opt.test_dir gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >=0: gpu_ids.append(id) # set gpu ids if len(gpu_ids)>0: torch.cuda.set_device(gpu_ids[0]) ###################################################################### # Load Data # --------- # # We will use torchvision and torch.utils.data packages for loading the # data. 
# data_transforms = transforms.Compose([ transforms.Resize((288,144), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ############### Ten Crop #transforms.TenCrop(224), #transforms.Lambda(lambda crops: torch.stack( # [transforms.ToTensor()(crop) # for crop in crops] # )), #transforms.Lambda(lambda crops: torch.stack( # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) # for crop in crops] # )) ]) if opt.PCB: data_transforms = transforms.Compose([ transforms.Resize((384,192), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) data_dir = test_dir if opt.multi: image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']} else: image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=False, num_workers=16) for x in ['gallery']} #class_names = image_datasets['query'].classes use_gpu = torch.cuda.is_available() ###################################################################### # Load model #--------------------------- def load_network(network): save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch) network.load_state_dict(torch.load(save_path)) return network ###################################################################### # Extract feature # ---------------------- # # Extract feature from a trained model. 
# def fliplr(img): '''flip horizontal''' inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W img_flip = img.index_select(3,inv_idx) return img_flip def extract_feature(model,dataloaders): features = torch.FloatTensor() count = 0 for data in dataloaders: img, label = data n, c, h, w = img.size() count += n print(count) if opt.use_dense: ff = torch.FloatTensor(n,1024).zero_() else: ff = torch.FloatTensor(n,2048).zero_() if opt.PCB: ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts for i in range(2): if(i==1): img = fliplr(img) input_img = Variable(img.cuda()) outputs = model(input_img) f = outputs.data.cpu() ff = ff+f # norm feature if opt.PCB: # feature size (n,2048,6) # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) ff = ff.div(fnorm.expand_as(ff)) ff = ff.view(ff.size(0), -1) else: fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) ff = ff.div(fnorm.expand_as(ff)) features = torch.cat((features,ff), 0) return features def get_id(img_path): camera_id = [] labels = [] for path, v in img_path: #filename = path.split('/')[-1] filename = os.path.basename(path) label = filename[0:4] camera = filename.split('c')[1] if label[0:2]=='-1': labels.append(-1) else: labels.append(int(label)) camera_id.append(int(camera[0])) return camera_id, labels gallery_path = image_datasets['gallery'].imgs #query_path = image_datasets['query'].imgs gallery_cam,gallery_label = get_id(gallery_path) #query_cam,query_label = get_id(query_path) if opt.multi: mquery_path = image_datasets['multi-query'].imgs mquery_cam,mquery_label = get_id(mquery_path) ###################################################################### # Load Collected data Trained model print('-------test-----------') if opt.use_dense: model_structure = ft_net_dense(751) else: model_structure = ft_net(751) if opt.PCB: 
model_structure = PCB(751) model = load_network(model_structure) # Remove the final fc layer and classifier layer if not opt.PCB: model.model.fc = nn.Sequential() model.classifier = nn.Sequential() else: model = PCB_test(model) # Change to test mode model = model.eval() if use_gpu: model = model.cuda() # Extract feature gallery_feature = extract_feature(model,dataloaders['gallery']) #query_feature = extract_feature(model,dataloaders['query']) if opt.multi: mquery_feature = extract_feature(model,dataloaders['multi-query']) # Save to Matlab for check #result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam} scipy.io.savemat('pytorch_result.mat',result) if opt.multi: result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam} scipy.io.savemat('multi_query.mat',result) 修改demo.py(将query路径下的图片生成特征并于gallery底库进行比对并展示) # -*- coding: utf-8 -*- from __future__ import print_function, division import argparse import torch import torch.nn as nn import torch.optim as optim from torch.optim import lr_scheduler from torch.autograd import Variable import numpy as np import torchvision from torchvision import datasets, models, transforms import time import os import scipy.io import matplotlib.pyplot as plt from model import ft_net, ft_net_dense, PCB, PCB_test ###################################################################### # Options # -------- parser = argparse.ArgumentParser(description='Training') parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 
0 0,1,2 0,2') parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last') parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data') parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path') parser.add_argument('--batchsize', default=32, type=int, help='batchsize') parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) parser.add_argument('--PCB', action='store_true', help='use PCB' ) parser.add_argument('--multi', action='store_true', help='use multiple query' ) parser.add_argument('--query_index', default=3, type=int, help='test_image_index') opt = parser.parse_args() str_ids = opt.gpu_ids.split(',') #which_epoch = opt.which_epoch name = opt.name test_dir = opt.test_dir gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >=0: gpu_ids.append(id) # set gpu ids if len(gpu_ids)>0: torch.cuda.set_device(gpu_ids[0]) ###################################################################### # Load Data # --------- # # We will use torchvision and torch.utils.data packages for loading the # data. 
# data_transforms = transforms.Compose([ transforms.Resize((288,144), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ############### Ten Crop #transforms.TenCrop(224), #transforms.Lambda(lambda crops: torch.stack( # [transforms.ToTensor()(crop) # for crop in crops] # )), #transforms.Lambda(lambda crops: torch.stack( # [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop) # for crop in crops] # )) ]) if opt.PCB: data_transforms = transforms.Compose([ transforms.Resize((384,192), interpolation=3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) data_dir = test_dir if opt.multi: image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']} else: image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']} dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, shuffle=False, num_workers=16) for x in ['gallery','query']} class_names = image_datasets['query'].classes use_gpu = torch.cuda.is_available() ###################################################################### # Load model #--------------------------- def load_network(network): save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch) network.load_state_dict(torch.load(save_path)) return network ###################################################################### # Extract feature # ---------------------- # # Extract feature from a trained model. 
# def fliplr(img): '''flip horizontal''' inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W img_flip = img.index_select(3,inv_idx) return img_flip def extract_feature(model,dataloaders): features = torch.FloatTensor() count = 0 for data in dataloaders: img,