环境
- ubuntu 18.04 64bit
- torch 1.7.1+cu101
数据集准备
这里以前面我们进行 YOLOv5 模型训练 时用到的口罩数据集为例,这个数据集来自网站 roboflow.com
,再次安利一下这个站点,真的非常棒,不止有详尽的博客教程,还有很多的开放数据集,而且支持的数据格式也很丰富,绝对值得经常去逛一逛。
口罩数据集下载地址:https://public.roboflow.com/object-detection/mask-wearing/4,这里也在百度网盘上存一份,需要的自取
链接:https://pan.baidu.com/s/1JvniT205zX79wASqiKtt5Q
提取码:9feh
数据集下载后解压,将文件夹重命名为 mask
,并放到 yolov7
的根目录下(这里可以随意,只要前后的路径匹配上就可以了),完整的目录结构如下
可以发现,其实 yolov7
和 yolov5
数据集格式是一模一样的,通过标注工具 labelimg
也可以得出同样的结论
训练
模型训练开始之前,我们需要修改部分配置文件
首先是数据集中的 data.yaml
train: mask/train/images
val: mask/valid/images
nc: 2
names: ['mask', 'no-mask']
其次,是修改 yolov7
中的配置文件 cfg/training/yolov7.yaml
,主要是 nc
这个字段
nc: 2 # number of classes
然后,就可以进行训练了,执行
python train.py --data mask/data.yaml --cfg cfg/training/yolov7.yaml --weights '' --name yolov7 --hyp data/hyp.scratch.p5.yaml
模型测试
最后,我们来测试下模型的效果
python detect.py --source mask/test/images/shutterstock_1627199179_jpg.rf.350e69105dd1458572a590c3e3ef2538.jpg --weight runs/train/yolov7/weights/best.pt
python detect.py --source mask/test/images/the-first-day-of-wuhan-s-closure-some-people-fled-some-panicked_jpg.rf.51ed69bf8d327d93b429a08581f6dea0.jpg --weight runs/train/yolov7/weights/best.pt