下载Person_reID_baseline_pytorch地址:
https://github.com/layumi/Person_reID_baseline_pytorch/tree/master/tutorial
下载Market1501数据集:
http://www.liangzheng.org/Project/project_reid.html
Market1501数据集结构:
├── Market/
│ ├── bounding_box_test/ /* Files for testing (candidate images pool)
│ ├── bounding_box_train/ /* Files for training
│ ├── gt_bbox/ /* We do not use it
│ ├── gt_query/ /* Files for multiple query testing
│ ├── query/ /* Files for testing (query images)
│ ├── readme.txt
修改
--test_dir
路径,执行
python prepare.py
之后的数据集结构:
├── Market/
│ ├── bounding_box_test/ /* Files for testing (candidate images pool)
│ ├── bounding_box_train/ /* Files for training
│ ├── gt_bbox/ /* We do not use it
│ ├── gt_query/ /* Files for multiple query testing
│ ├── query/ /* Files for testing (query images)
│ ├── readme.txt
│ ├── pytorch/
│ │ ├── train/ /* train
│ │ │ ├── 0002
│ │ │ ├── 0007
│ │ │ ...
│ │ ├── val/ /* val
│ │ ├── train_all/ /* train+val
│ │ ├── query/ /* query files
│ │ ├── gallery/ /* gallery files
训练模型并测试,修改
train.py、test.py
中的
--test_dir
路径
/home/hylink/eclipse-workspace/reID/Market/pytorch
:
python train.py
python test.py
python demo.py --query_index 777
效果展示:
修改test.py(将原gallery和query生成底库,改为只生成gallery底库)
from __future__ import print_function, division
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import scipy.io
from model import ft_net, ft_net_dense, PCB, PCB_test
# Command-line options for the gallery feature-extraction script.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--test_dir', default='/home/hylink/eclipse-workspace/reID/Market/pytorch', type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121')
parser.add_argument('--PCB', action='store_true', help='use PCB')
parser.add_argument('--multi', action='store_true', help='use multiple query')
opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
name = opt.name
test_dir = opt.test_dir

# Keep only the non-negative ids of the comma-separated list; negative ids
# are dropped, so "--gpu_ids -1" leaves the list empty and no CUDA device is
# selected below. A comprehension is used so that no loop variable (the
# original one was called `id`) shadows a builtin at module level.
gpu_ids = [gpu_id for gpu_id in map(int, str_ids) if gpu_id >= 0]

# The first listed id becomes the current CUDA device.
if gpu_ids:
    torch.cuda.set_device(gpu_ids[0])
# Pre-processing applied to every image: resize, convert to a tensor, then
# normalise each channel with the mean / std listed below.
# --PCB uses a 384x192 input; every other model uses 288x144.
_resize_hw = (384, 192) if opt.PCB else (288, 144)
data_transforms = transforms.Compose([
    transforms.Resize(_resize_hw, interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
data_dir = test_dir

# Sub-folders of data_dir to load. Without --multi only the gallery is read.
_subsets = ['gallery', 'query', 'multi-query'] if opt.multi else ['gallery']

# One ImageFolder dataset and one sequential (unshuffled) loader per
# sub-folder, both keyed by the sub-folder name.
image_datasets = {}
for _subset in _subsets:
    image_datasets[_subset] = datasets.ImageFolder(
        os.path.join(data_dir, _subset), data_transforms)

dataloaders = {}
for _subset in _subsets:
    dataloaders[_subset] = torch.utils.data.DataLoader(
        image_datasets[_subset],
        batch_size=opt.batchsize,
        shuffle=False,
        num_workers=16,
    )

use_gpu = torch.cuda.is_available()
def load_network(network):
    """Load the saved weights of the configured run into ``network``.

    The checkpoint is read from ``./model/<name>/net_<which_epoch>.pth``,
    where ``name`` and ``opt.which_epoch`` are the module-level settings
    taken from the command line.

    Args:
        network: Object whose ``load_state_dict`` receives the checkpoint.

    Returns:
        The same ``network`` object, after its weights have been loaded.
    """
    save_path = os.path.join('./model', name, 'net_%s.pth' % opt.which_epoch)
    # Deserialize every storage onto the CPU regardless of where it was saved
    # from, so a checkpoint written on a GPU machine can still be read when
    # no CUDA device is available. load_state_dict then copies the values
    # into the network's own parameters.
    state_dict = torch.load(save_path, map_location=lambda storage, loc: storage)
    network.load_state_dict(state_dict)
    return network
def fliplr(img):
    """Return ``img`` with the order of its entries along dimension 3 reversed.

    Args:
        img: Tensor with at least four dimensions; dimension 3 is the one
            that gets mirrored.

    Returns:
        A new tensor of the same shape as ``img``.
    """
    last_col = img.size(3) - 1
    # Descending index list: last_col, ..., 1, 0.
    descending = torch.arange(last_col, -1, -1).long()
    return img.index_select(3, descending)
def extract_feature(model, dataloaders):
    """Compute one L2-normalised feature vector per image of a data loader.

    Every batch is passed through ``model`` twice, once as loaded and once
    mirrored with ``fliplr``; the two outputs are summed and the sum is
    normalised.

    Args:
        model: Callable mapping an image batch to a feature tensor. Its
            output must be addable to a zero tensor of shape ``(n, 2048, 6)``
            with ``--PCB``, ``(n, 1024)`` with ``--use_dense`` and
            ``(n, 2048)`` otherwise.
        dataloaders: Iterable yielding ``(img, label)`` batches; only ``img``
            is used.

    Returns:
        A CPU float tensor with one row per image, in loader order. With
        ``--PCB`` the ``(2048, 6)`` part features are flattened into a single
        row. An empty ``FloatTensor`` is returned if the loader yields no
        batch.
    """
    batch_features = []
    count = 0
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for img, _ in dataloaders:
            n = img.size(0)
            count += n
            print(count)

            # Accumulator for the sum of the two forward passes.
            if opt.PCB:
                ff = torch.FloatTensor(n, 2048, 6).zero_()
            elif opt.use_dense:
                ff = torch.FloatTensor(n, 1024).zero_()
            else:
                ff = torch.FloatTensor(n, 2048).zero_()

            for i in range(2):
                if i == 1:
                    img = fliplr(img)
                # Follow the script-wide use_gpu switch instead of always
                # calling .cuda(), so CPU-only hosts do not crash here.
                input_img = img.cuda() if use_gpu else img
                outputs = model(input_img)
                ff = ff + outputs.detach().cpu()

            if opt.PCB:
                # Normalise each of the 6 slices along dim 1 on its own; the
                # extra sqrt(6) makes the flattened row have unit L2 norm.
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6)
                ff = ff.div(fnorm.expand_as(ff))
                ff = ff.view(ff.size(0), -1)
            else:
                fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                ff = ff.div(fnorm.expand_as(ff))

            batch_features.append(ff)

    if not batch_features:
        return torch.FloatTensor()
    # Concatenate once at the end rather than re-copying the growing result
    # after every batch.
    return torch.cat(batch_features, 0)
def get_id(img_path):
    """Parse identity labels and camera ids out of image file names.

    The label is read from the first four characters of the file name; a
    name starting with ``-1`` yields the label ``-1``. The camera id is the
    single digit that follows the first ``c`` in the file name.

    Args:
        img_path: Iterable of two-element items whose first element is the
            image path; the second element is ignored.

    Returns:
        A ``(camera_id, labels)`` tuple of two lists of ints, in input order.
    """
    camera_id, labels = [], []
    for path, _ in img_path:
        filename = os.path.basename(path)
        person = filename[:4]
        after_c = filename.split('c')[1]
        labels.append(-1 if person.startswith('-1') else int(person))
        camera_id.append(int(after_c[0]))
    return camera_id, labels
# Camera ids and identity labels parsed from the image file names.
gallery_path = image_datasets['gallery'].imgs
gallery_cam, gallery_label = get_id(gallery_path)

if opt.multi:
    mquery_path = image_datasets['multi-query'].imgs
    mquery_cam, mquery_label = get_id(mquery_path)

print('-------test-----------')

# Build only the architecture that is actually used. --PCB takes precedence
# over --use_dense, so no backbone is constructed just to be thrown away.
# 751 is the class count handed to the constructors -- presumably the number
# of training identities; it has to match the saved checkpoint.
if opt.PCB:
    model_structure = PCB(751)
elif opt.use_dense:
    model_structure = ft_net_dense(751)
else:
    model_structure = ft_net(751)

model = load_network(model_structure)

# Turn the classifier into a feature extractor: PCB is wrapped in PCB_test,
# the other models get their final fc / classifier layers replaced by empty
# nn.Sequential() modules.
if opt.PCB:
    model = PCB_test(model)
else:
    model.model.fc = nn.Sequential()
    model.classifier = nn.Sequential()

model = model.eval()
if use_gpu:
    model = model.cuda()

gallery_feature = extract_feature(model, dataloaders['gallery'])
if opt.multi:
    mquery_feature = extract_feature(model, dataloaders['multi-query'])

# Save the gallery features together with their labels and camera ids.
result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_cam': gallery_cam}
scipy.io.savemat('pytorch_result.mat', result)

if opt.multi:
    result = {'mquery_f': mquery_feature.numpy(), 'mquery_label': mquery_label, 'mquery_cam': mquery_cam}
    scipy.io.savemat('multi_query.mat', result)
修改demo.py(将query路径下的图片生成特征并与gallery底库进行比对并展示)
from __future__ import print_function, division
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import scipy.io
import matplotlib.pyplot as plt
from model import ft_net, ft_net_dense, PCB, PCB_test
# Command-line options for the query-vs-gallery demo script.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--test_dir', default='/home/hylink/eclipse-workspace/reID/Market/pytorch', type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121')
parser.add_argument('--PCB', action='store_true', help='use PCB')
parser.add_argument('--multi', action='store_true', help='use multiple query')
parser.add_argument('--query_index', default=3, type=int, help='test_image_index')
opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
name = opt.name
test_dir = opt.test_dir

# Keep only the non-negative ids of the comma-separated list; negative ids
# are dropped, so "--gpu_ids -1" leaves the list empty and no CUDA device is
# selected below. A comprehension is used so that no loop variable (the
# original one was called `id`) shadows a builtin at module level.
gpu_ids = [gpu_id for gpu_id in map(int, str_ids) if gpu_id >= 0]

# The first listed id becomes the current CUDA device.
if gpu_ids:
    torch.cuda.set_device(gpu_ids[0])
# Pre-processing applied to every image: resize, convert to a tensor, then
# normalise each channel with the mean / std listed below.
# --PCB uses a 384x192 input; every other model uses 288x144.
_resize_hw = (384, 192) if opt.PCB else (288, 144)
data_transforms = transforms.Compose([
    transforms.Resize(_resize_hw, interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
data_dir = test_dir

# Sub-folders of data_dir to load; --multi additionally reads 'multi-query'.
_subsets = ['gallery', 'query']
if opt.multi:
    _subsets.append('multi-query')

# One ImageFolder dataset and one sequential (unshuffled) loader per
# sub-folder, both keyed by the sub-folder name.
image_datasets = {}
for _subset in _subsets:
    image_datasets[_subset] = datasets.ImageFolder(
        os.path.join(data_dir, _subset), data_transforms)

dataloaders = {}
for _subset in _subsets:
    dataloaders[_subset] = torch.utils.data.DataLoader(
        image_datasets[_subset],
        batch_size=opt.batchsize,
        shuffle=False,
        num_workers=16,
    )

# Class (sub-directory) names of the query set, as listed by ImageFolder.
class_names = image_datasets['query'].classes
use_gpu = torch.cuda.is_available()
def load_network(network):
    """Load the saved weights of the configured run into ``network``.

    The checkpoint is read from ``./model/<name>/net_<which_epoch>.pth``,
    where ``name`` and ``opt.which_epoch`` are the module-level settings
    taken from the command line.

    Args:
        network: Object whose ``load_state_dict`` receives the checkpoint.

    Returns:
        The same ``network`` object, after its weights have been loaded.
    """
    save_path = os.path.join('./model', name, 'net_%s.pth' % opt.which_epoch)
    # Deserialize every storage onto the CPU regardless of where it was saved
    # from, so a checkpoint written on a GPU machine can still be read when
    # no CUDA device is available. load_state_dict then copies the values
    # into the network's own parameters.
    state_dict = torch.load(save_path, map_location=lambda storage, loc: storage)
    network.load_state_dict(state_dict)
    return network
def fliplr(img):
    """Return ``img`` with the order of its entries along dimension 3 reversed.

    Args:
        img: Tensor with at least four dimensions; dimension 3 is the one
            that gets mirrored.

    Returns:
        A new tensor of the same shape as ``img``.
    """
    n_cols = img.size(3)
    # Index list n_cols-1, ..., 1, 0 picks the columns in reverse order.
    reversed_cols = torch.arange(n_cols - 1, -1, -1).long()
    mirrored = img.index_select(3, reversed_cols)
    return mirrored
def extract_feature(model,dataloaders):
features = torch.FloatTensor()
count = 0
for data in dataloaders:
img,