PAI平台提供图像多标签分类相关算法,支持千万级别超大规模的图片样本训练。本文为您介绍如何使用PAI命令基于图片数据生成图像多标签分类模型。
图像分类训练
您可以使用SQL脚本组件进行PAI命令调用,也可以使用MaxCompute客户端或DataWorks的开发节点进行PAI命令调用。如何使用MaxCompute客户端和创建DataWorks的开发节点,详情请参见使用本地客户端(odpscmd)连接或创建并管理MaxCompute节点。
图像单标签分类单机训练
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -DgpuRequired=100 -Dcmd train -Dparam_config '--model_type Classification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4 --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
图像单标签分类多机训练
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -Dcmd train -Dcluster='{ \"ps\": { \"count\" : 1, \"cpu\" : 600 }, \"worker\" : { \"count\" : 3, \"cpu\" : 800, \"gpu\" : 100 } }' -Dparam_config='--model_type Classification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4_dis --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
图像多标签单机训练
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -DgpuRequired=100 -Dcmd train -Dparam_config '--model_type MultiLabelClassification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4 --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
图像多标签多机训练
pai -name easy_vision_ext -Dbuckets='oss://{bucket_name}.{oss_host}/{path}' -Darn='acs:ram::*********:role/aliyunodpspaidefaultrole' -DossHost='{oss_host}' -Dcmd train -Dcluster='{ \"ps\": { \"count\" : 1, \"cpu\" : 600 }, \"worker\" : { \"count\" : 3, \"cpu\" : 800, \"gpu\" : 100 } }' -Dparam_config='--model_type MultiLabelClassification --backbone inception_v4 --num_classes 10 --num_epochs 1 --model_dir oss://examplebucket/test/cifar_inception_v4_dis --use_pretrained_model true --train_data oss://examplebucket/data/test/cifar10/*.tfrecord --test_data oss://examplebucket/data/test/cifar10/*.tfrecord --num_test_example 20 --train_batch_size 32 --test_batch_size=32 --image_size 299 --initial_learning_rate 0.01 --staircase true'
参数说明
参数 | 是否必选 | 描述 | 取值格式 | 默认值 |
buckets | 是 | OSS Bucket地址。Bucket必须以正斜线(/)结尾。 | oss://{bucket_name}.{oss_host}/{path} | 无 |
arn | 是 | 访问OSS的授权。您可以登录PAI控制台,在全部产品依赖页面的Designer区域,单击操作列下的查看授权信息,获取arn,具体操作请参见云产品依赖与授权:Designer。 | acs:ram::*:role/AliyunODPSPAIDefaultRole | 无 |
ossHost | 否 | OSS访问域名,详情请参见访问域名和数据中心。如果未指定该参数,则从Buckets参数中获取。 | oss-{region}.aliyuncs.com | 从Buckets参数中获取 |
cluster | 否 | 分布式训练参数相关配置。 | JSON格式字符串 | “” |
gpuRequired | 否 | 标识是否使用GPU,默认使用一张卡。如果取值200,则一个Worker申请2张卡。 | 100 | 100 |
cmd | 是 | EasyVision任务类型。模型训练时,该参数应取值为train。 | train | 无 |
param_config | 是 | 模型训练参数,其格式与Python Argparser参数格式一致,详情请参见param_config说明。 | STRING | 无 |
param_config说明
param_config包含若干模型配置相关参数,格式为Python Argparser,示例如下。
-Dparam_config = '--model_type MultiLabelClassification --backbone inception_v4 --num_classes 200 --model_dir oss://your/bucket/exp_dir'
所有字符串类型的参数,其取值均不加引号。
参数名称 | 是否必选 | 参数描述 | 取值格式 | 默认值 |
model_type | 是 | 训练模型类型。多标签分类的模型类型为MultiLabelClassification。 | STRING | 无 |
backbone | 否 | 识别模型的网络名称,取值包括:
| STRING | inception_v4 |
num_classes | 是 | 分类类别数量。 | 100 | 无 |
image_size | 否 | 图片Resize后的大小,单位为像素。 | INT | 224 |
use_crop | 否 | 是否使用crop进行数据增强。 | BOOL | true |
eval_each_category | 否 | 是否针对每个类别单独进行评估。 | BOOL | false |
optimizer | 否 | 优化方法,取值包括:
| STRING | momentum |
lr_type | 否 | 学习率调整策略,取值包括:
| STRING | exponential_decay |
initial_learning_rate | 否 | 初始学习率。 | FLOAT | 0.01 |
decay_epochs | 否 | 如果使用exponential_decay,该参数对应tf.train.exponential.decay中的decay_steps,系统会自动根据训练数据总数将decay_epochs转换为decay_steps。例如,取值为10,通常是总Epoch数的1/2。 如果使用manual_step,该参数表示需要调整学习率的迭代轮数。例如16 18表示在16 Epoch和18 Epoch对学习率进行调整。通常将这两个值配置为总Epoch的8/10和9/10。 | 整数列表,例如20 20 40 60。 | 20 |
decay_factor | 否 | tf.train.exponential.decay中的decay_factor。 | FLOAT | 0.95 |
staircase | 否 | tf.train.exponential.decay中的staircase。 | BOOL | true |
power | 否 | tf.train.polynomial.decay中的power。 | FLOAT | 0.9 |
learning_rates | 否 | manual_step学习率调整策略中使用的参数,表示在指定Epoch中学习率的取值。 如果您指定的调整Epoch有两个,则需要在此指定两个Epoch对应的学习率。例如,如果decay_epochs为20 40,则该将参数配置为0.001 0.0001,表示在20 Epoch学习率调整为0.001,40 Epoch学习率调整为0.0001。建议几次调整的学习率依次为初始学习率的1/10、1/100及1/1000。 | 浮点列表 | 无 |
train_data | 是 | 训练数据文件的OSS路径。 | oss://path/to/train_*.tfrecord | 无 |
test_data | 是 | 训练过程中,评估数据的OSS路径。 | oss://path/to/test_*.tfrecord | 无 |
train_batch_size | 是 | 训练的batch_size。 | INT,例如32。 | 无 |
test_batch_size | 是 | 评估的batch_size。 | INT,例如32。 | 无 |
train_num_readers | 否 | 训练数据并发读取线程数。 | INT | 4 |
model_dir | 是 | 训练的OSS目录。 | oss://path/to/model | 无 |
pretrained_model | 否 | 预训练模型OSS路径。如果指定该路径,则在该模型基础上进行微调。 | oss://pai-vision-data-sh/pretrained_models/inception_v4.ckpt | “” |
use_pretrained_model | 否 | 是否使用预训练模型。 | BOOL | true |
num_epochs | 是 | 训练迭代次数。取值1表示对所有训练数据都进行一次迭代。 | INT,例如40。 | 无 |
num_test_example | 否 | 训练过程中评估数据条目数。取值 -1表示使用所有测试数据作为评估数据。 | INT,例如2000。 | -1 |
num_visualizations | 否 | 评估过程可视化显示的样本数量。 | INT | 10 |
save_checkpoint_epochs | 否 | 保存Checkpoint的频率,以Epoch为单位。取值为1表示每完成一次训练就保存一次Checkpoint。 | INT | 1 |
num_train_images | 否 | 总的训练样本数。如果使用自己生成的TFRecord,则需要指定该参数。 | INT | 0 |
label_map_path | 否 | 类别映射文件。如果使用自己生成的TFRecord,则需要指定该参数。 | STRING | ”” |
相关文档
与图像分类模型不同,多标签分类的多个类别并不互斥,图像多标签分类模型会输出识别概率达到一定阈值的所有类别。您可以将生成的模型部署至EAS,详情请参见服务部署:控制台。