随着PaddlePaddle2.0的更新,PaddleClas图像分类套件也更新到了2.0-rc1版本。新版本的PaddleClas套件已经默认使用动态图来进行模型训练。现在我们使用PaddleClas套件从零开始实现一个简单的垃圾分类器。来体验一下新版本的PaddleClas的的方便快捷,即使初学者也能快速的训练出高精度的模型。本篇文章分为上下两部分,上部讲解如何从零开始训练,下部讲解部分核心代码以及深度学习训练过程中使用到的技术。
1.准备数据集
数据集下载地址:
https://aistudio.baidu.com/aistudio/datasetdetail/64185
下载好数据集之后,首先需要解压压缩包。
1 | mkdir dataset |
数据集中共包含43个分类,例如:9代表”厨余垃圾/水果果肉”、22代表”可回收物/旧衣服、39代表有害垃圾/过期药物”。具体类别可以查看garbage_classify中的garbage_classify_rule.json文件。
有了数据集之后,需要对数据集进行划分。在dataset目录下创建process_dataset.py文件,使用下列代码将数据集划分为训练集、验证集和测试集,划分比例为8:1:1。
1 | import os |
以上代码运行结束后,目录结构如下:
1 | ├── garbage_classify |
2.下载PaddleClas套件
下载PaddleClas源代码,并切换到2.0-rc1版本。安装该套件依赖软件可参考以下文档:
https://github.com/PaddlePaddle/PaddleClas/blob/release/2.0-rc1/docs/en/tutorials/install_en.md1
2
3git clone https://github.com/PaddlePaddle/PaddleClas.git
git fetch
git branch release/2.0-rc1 origin/release/2.0-rc1
3.修改配置文件
PaddleClas套件中包含了多种神经网络模型,也包含了模型对应的训练参数,配置参数保存在configs路径下。本次的垃圾分类器我选择一个工业界常用的ResNet50网络作为分类器。首先通过拷贝的方式新建一个垃圾分类器的配置文件。1
2cd PaddleClas/configs/ResNet/
cp ResNet50_vd.yaml garbage_ResNet50_vd.yaml
然后修改garbage_ResNet50_vd.yaml内容如下:
1 | mode: 'train' |
4.开始训练
为了加快模型的收敛,同时提升模型的精度,这里我选择先加载预训练模型,然后对模型进行微调。首先需要下载预训练权重。
1 | wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_pretrained.pdparams |
然后开始训练模型:
1 | python tools/train.py \ |
训练过程中输入日志如下:
1 | W1214 20:29:28.872682 1473 device_context.cc:338] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 10.1 |
5.模型评估
为了可以快速的看到效果,训练100个epoch之后,可以先停止训练。当前最优模型在验证集上的精度为top1: 0.90589, top5: 0.98966。
然后我们在测试集上评估一下最优模型的精度。
将PaddleClas/configs/ResNet/garbage_ResNet50_vd.yaml文件中验证集的路径改为测试集。
1 | VALID: |
开始评估模型,
1 | python tools/eval.py -c \ |
运行结果如下:
1 | 2020-12-15 09:08:25 INFO: epoch:0 , valid step:0 , loss: 1.05716, top1: 0.89062, top5: 1.00000, lr: 0.000000, batch_cost: 0.75766 s, reader_cost: 0.68446 s, ips: 84.47009 images/sec. |
可以看出当前的最优模型在测试集上的精度为top1: 0.90331, top5: 0.99018。准确率可以达到90%,当然这个精度还是可以继续提升的。可以通过调参、更换模型和数据增强进一步提升模型精度。
下一篇会解析一下PaddleClas套件中的核心代码,以及一些调优的策略。
PaddleClas仓库地址:https://github.com/PaddlePaddle/PaddleClas