加微信进交流群:xituxiaoshutong100

Market 1501数据集及在DeepSort中的训练

PyTorch 迷途小书童 4评论

软硬件环境

  • ubuntu 18.04 64bit
  • GTX 1070Ti
  • anaconda with python 3.7
  • pytorch 1.6
  • cuda 10.1

Market 1501数据集

Market-1501数据集是在清华大学校园中采集,在夏天拍摄,于2015年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低分辨率摄像头)拍摄到的1501个行人、32668个检测到的行人矩形框。每个行人至少有2个摄像头捕捉到,并且在一个摄像头中可能具有多张图像。训练集有751人,包含12936张图像,平均每个人有17.2张训练数据;测试集有750人,包含19732张图像,平均每个人有26.3张测试数据。3368张查询图像的行人检测矩形框是人工绘制的,而gallery中的行人检测矩形框则是使用DPM检测器检测得到的。

数据集目录结构

Market-1501-v15.09.15
├── bounding_box_test
├── bounding_box_train
├── gt_bbox
├── gt_query
├── query
└── readme.txt

包含四个文件夹

  • bounding_box_test: 用于测试
  • bounding_box_train: 用于训练
  • query: 有750个身份。我们为每个摄像机随机选择一个查询图像
  • gt_query: 包含实际标注
  • gt_bbox: 手绘边框,主要用于判断DPM边界框是否良好

图片命名规则

0001_c1s1_000151_01.jpg为例

  • 0001表示每个人的标签编号,从0001到1501,共有1501个人
  • c1表示第一个摄像头(ccamera),共有6个摄像头
  • s1 表示第一个录像片段(ssequence),每个摄像机都有多个录像片段
  • 000151表示c1s1的第000151帧图片,视频帧率fps为25
  • 01表示c1s1_001051这一帧上的第1个检测框,由于采用DPM自动检测器,每一帧上的行人可能会有多个,相应的标注框也会有多个。00则表示手工标注框

数据集下载地址:

链接:https://pan.baidu.com/s/1i9aiZx-EC3fjhn3uWTKZjw
提取码:up8x

deepsort模型训练

前文 《基于YOLOv5和DeepSort的目标跟踪https://xugaoxiang.com/2020/10/17/yolov5-deepsort-pytorch/ 介绍过利用YOLOv5DeepSort来实现目标的检测及跟踪。现在我们使用Market 1501数据集来训练跟踪器模型。

至于YOLOv5检测模型的训练,参考前面的博文 YOLOv5模型训练。我们使用原作者提供的yolov5s.pt就可行。

依赖环境就不再说了,参考前文

git clone --recurse-submodules https://github.com/mikel-brostrom/Yolov5_DeepSort_Pytorch.git
cd Yolov5_DeepSort_Pytorch/deep_sort/deep_sort/deep

接下来将数据集Market拷贝到Yolov5_DeepSort_Pytorch/deep_sort/deep_sort/deep下然后解压,数据集存放的位置是随意的,可以通过参数--data-dir指定。

针对原项目中的训练代码train.py,需要做点修改

train_dir = os.path.join(root,"train")
test_dir = os.path.join(root,"test")

改成

train_dir = os.path.join(root,"")
test_dir = os.path.join(root,"")

然后将数据集中的文件夹bounding_box_train重命名为trainbounding_box_test重命名为test。不然的话,训练的时候就会报下面2个错

Traceback (most recent call last):
  File "train.py", line 43, in <module>
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
  File "/home/xugaoxiang/anaconda3/envs/deepsort/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 208, in __init__
    is_valid_file=is_valid_file)
  File "/home/xugaoxiang/anaconda3/envs/deepsort/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 100, in __init__
    raise RuntimeError(msg)
RuntimeError: Found 0 files in subfolders of: data/train
Supported extensions are: .jpg,.jpeg,.png,.ppm,.bmp,.pgm,.tif,.tiff,.webp

Traceback (most recent call last):
  File "train.py", line 43, in <module>
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
  File "/home/xugaoxiang/anaconda3/envs/deepsort/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 208, in __init__
    is_valid_file=is_valid_file)
  File "/home/xugaoxiang/anaconda3/envs/deepsort/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 94, in __init__
    classes, class_to_idx = self._find_classes(self.root)
  File "/home/xugaoxiang/anaconda3/envs/deepsort/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 123, in _find_classes
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
FileNotFoundError: [Errno 2] No such file or directory: 'Market-1501-v15.09.15/train'

最后,还需要改个地方,编辑model.py,将

def __init__(self, num_classes=751 ,reid=False):

改成

def __init__(self, num_classes=5 ,reid=False):

然后就可以开始训练了

python train.py --data-dir Market-1501-v15.09.15

deepsort_market_pytorch

deepsort_market_pytorch

训练结束后,会在checkpoint下生成模型文件ckpt.t7,找个视频,测试一下

补充

Market到底怎么组织

原始的数据集结构是这样的

Market-1501-v15.09.15
├── bounding_box_test
├── bounding_box_train
├── gt_bbox
├── gt_query
├── query
└── readme.txt

bounding_box_trainbounding_box_test目录下就是具体的图片文件了,这里面并没有体现id。正确的做法是:将某个id(也就是某个人)的图片放在一个文件夹内,且以该id作为文件夹的名称。如将bounding_box_train下所有以0002开头的图片文件存放在文件夹0002

market deepsort

bounding_box_test下的图片处理也是一样,test中有个id是-1,嗯?真没弄懂,但不影响训练。

market deepsort

针对上述的操作,写了个简单的脚本

import os
import sys
import shutil

if __name__ == '__main__':
    root = os.path.join(sys.argv[1], 'dataset')
    os.mkdir(root)
    train_dir = os.path.join(root, 'train')
    test_dir = os.path.join(root, 'test')
    os.mkdir(train_dir)
    os.mkdir(test_dir)

    # 处理train
    for file in os.listdir(os.path.join(sys.argv[1], 'bounding_box_train')):
        print(file)
        id = file.split('_')[0]
        if not os.path.exists(os.path.join(train_dir, id)):
            os.mkdir(os.path.join(train_dir, id))
        else:
            shutil.copy(os.path.join(sys.argv[1], 'bounding_box_train', file), os.path.join(train_dir, id))

    # 处理test
    for file in os.listdir(os.path.join(sys.argv[1], 'bounding_box_test')):
        id = file.split('_')[0]
        if not os.path.exists(os.path.join(test_dir, id)):
            os.mkdir(os.path.join(test_dir, id))
        else:
            shutil.copy(os.path.join(sys.argv[1], 'bounding_box_test', file), os.path.join(test_dir, id))

使用方法

python test.py Market-1501-v15.09.15

脚本执行结束后,会在Market-1501-v15.09.15下生成文件夹dataset,文件结构是这样的

dataset/
├── train
    ├── 0002
    ├── 0007
    ├── 0010
    ├── 0011
    ├── 0012
    ├── 0020
    ├── 0022

├── test
    ├── 0000
    ├── 0001
    ├── 0003
    ├── 0004
    ├── 0005
    ├── 0006
    ├── 0008
    ├── 0009

这样就生成了一份可直接训练的数据集,而原有的也不会被破坏。

num_classes含义

这里解释下num_classes的含义,根据原工程 https://github.com/mikel-brostrom/Yolov5_DeepSort_Pytorchtrain.py的代码

trainloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
    batch_size=64,shuffle=True
)
testloader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
    batch_size=64,shuffle=True
)

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

可以看到,num_classestraintest集合中类型(也就是总id数)数目较大者的值,在Market 1501数据集中,train中有751个,test中有752(包括了一个id号为-1)个,因此,num_classes就是572。

因此,在训练数据集的时候只需要修改model.py,将num_classes改成572,train.py无需修改。

参考资料

喜欢 (0)
发表我的评论
取消评论

表情
(4)个小伙伴在吐槽
  1. 我是在GoogleColab 上运行的
    匿名2021-04-15 11:51 (1天前)回复
  2. 博主你好,能帮忙解决一下在运行python track.py --source 文件.mp4 时出现提示cannot connect to X server ?
    匿名2021-04-13 13:08 (3天前)回复
    • 怎么会用到x server的?你在什么平台上跑的?
      迷途小书童2021-04-13 20:24 (3天前)回复
  3. 我发现deep里面的feature_extract.py要在Extract类里的transforms里加一句通道数的变换才能运行,不然就报维度错误啊
    匿名2021-03-18 17:07 回复