文本分类(MaxCompute)算法组件是以原始文本作为输入,输出类别标签。该算法组件集成了多种基于BERT的文本分类模型。本文介绍文本分类训练(MaxCompute)算法组件的配置方法及使用示例。
注意事项
该组件目前仅支持读取BOOLEAN、BIGINT、DOUBLE、STRING和DATETIME类型的数据。
算法简介
文本分类训练(MaxCompute)是通用的基于BERT的分类模型,输入文本数据,输出分类标签,适用于文本打标和文本情感分析等任务。模型结构如下图所示。

您可以通过以下两种方式使用文本分类(MaxCompute)算法:
- 在PAI-Designer中,通过可视化的方式配置组件参数,详情请参见下文的可视化配置参数。
文本分类(MaxCompute)算法组件位于组件库自然语言处理文件夹。
- 在DataWorks的ODPS SQL节点中,通过PAI命令的方式调用算法,详情请参见下文的PAI命令及说明。关于DataWorks的ODPS SQL节点请参见创建ODPS SQL节点。
前提条件
已开通OSS并完成授权,详情请参见开通OSS服务和PAI访问云产品授权:OSS。
使用限制
仅PAI-Designer提供该算法组件,使用前必须开通MaxCompute资源组,并且使用GPU。
可视化配置参数
- 输入桩
输入桩(从左到右) 数据类型 建议上游组件 是否必选 训练数据 MaxCompute表 读数据表 是 测试数据 MaxCompute表 读数据表 是 - 组件参数
页签 参数 是否必选 描述 默认值 字段设置 文本列选择 是 文本序列在输入表中对应的列名。 无 标签列选择 是 分类标签对应的列名。 无 标签枚举值 是 您需要枚举出所有标签,多个标签之间使用半角逗号(,)分隔。 无 样本权重列 否 样本加权列。每个样本的loss计算时,您可以加权重。 无 模型存储路径 是 模型Checkpoint的存储路径。例如 oss://pai-asr-test-bj/test_text_match3/
。无 参数设置 优化器类型 否 选择优化器类型,支持以下取值: - adam
- adagrad
- lamb
adam batchSize 否 特征提取的批大小,取值为INT类型。 32 sequenceLength 否 序列整体最大长度,取值范围为1~512。 128 numEpochs 否 训练的轮次,取值为INT类型。 3 学习率 否 优化器的学习率,取值为FLOAT类型。 2e-5 pretrainModelNameOrPath 否 选择预训练模型,例如: - pai-bert-base-zh
- pai-bert-small-zh
- pai-bert-tiny-zh
pai-bert-base-zh
模型额外参数 否 额外的参数。例如修改预训练模型,您可以将该参数配置为 pretrain_model_name_or_path=pai-bert-base-zh
。无
执行调优 指定Worker数 否 用于计算的Worker数量。 1 指定Worker的GPU卡数 否 每个Worker中的GPU卡数量。 1 指定Worker的CPU卡数 否 每个Worker中的CPU卡数量。 1 分布式策略 是 运行模式,支持以下取值: - MirroredStrategy(单机多卡)
- ExascaleStrategy(多机多卡)
MirroredStrategy(单机多卡) - 输出桩
输出桩(从左到右) 数据类型 下游组件 输出模型 OSS路径。该路径是您在字段设置页签的模型存储路径参数配置的OSS路径,训练生成SavedModel格式的模型存储在该路径下。 文本分类预测(MaxCompute)
PAI命令及说明
您也可以通过如下PAI命令,使用文本分类(MaxCompute)算法。
pai -name easy_transfer_app_ext
-Dmode=train
-DmodelName=text_classify_bert
-DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}
-DfirstSequence=content
-DlabelName=label
-DlabelEnumerateValues=100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116
-DsequenceLength=64
-DcheckpointDir=oss://${your_bucket}/${your_path}
-DbatchSize=32
-DnumEpochs=1
-DoptimizerType=adam
-DlearningRate=2e-5
-DuserDefinedParameters='pretrain_model_name_or_path=pai-bert-base-zh'
-Dbuckets=oss://${your_bucket}/
-Darn=${your_role_arn}
-DossHost=${your_host}
PAI命令中的参数详情如下表所示。参数名称 | 是否必选 | 描述 | 类型 | 默认值 |
---|---|---|---|---|
mode | 是 | 模式,支持以下取值:
| STRING | 无 |
modelName | 是 | 模型名称,与应用一一对应。支持以下模型:
| STRING | 无 |
inputTable | 是 | MaxCompute输入表的表名。 | STRING | 无 |
firstSequence | 是 | 文本序列在输入表中对应的列名。 | STRING | 无 |
labelName | 是 | 分类标签对应的列名。 | STRING | 无 |
labelEnumerateValues | 是 | 枚举出所有标签值。 | STRING | 无 |
sequenceLength | 是 | 序列整体最大长度。 | BIGINT | 无 |
checkpointDir | 是 | 模型Checkpoint的存储路径。例如oss://easynlp-sh/text_match/ 。 | STRING | 无 |
batchSize | 是 | 特征提取的批大小。 | BIGINT | 无 |
numEpochs | 是 | 训练的轮次。 | BIGINT | 无 |
optimizerType | 是 | 优化器,例如adam。 | STRING | 无 |
learningRate | 是 | 优化器的学习率,例如3e-5。 | DOUBLE | 无 |
userDefinedParameters | 是 | 额外的参数。例如修改预训练模型,您可以将该参数配置为pretrain_model_name_or_path=pai-bert-base-zh 。 | STRING | 无 |
buckets | 是 | 需要鉴权的OSS Bucket,与CheckpointDir对应。例如oss://easynlp-sh/ 。 | STRING | 无 |
arn | 是 | 您的ARN配置。 | STRING | 无 |
ossHost | 是 | 您Bucket对应的OSS Host。 | STRING | 无 |
运行完成后,输出的模型存储在PAI命令的CheckpointDir参数中配置的OSS路径下,您可以登录OSS管理控制台查看模型信息。输出结果的示例如下图所示。
上述文件包括:

- 模型中间结果:avg_loss是训练loss,eval是评测结果,variables是模型参数,其他的文件为模型Checkpoint和Meta信息。
- 可部署的模型:deployment存放可以部署的模型,可以直接对接PAI-EAS服务。
支持的计算资源
MaxCompute
示例
- 下载训练数据集和测试数据集。本示例使用的训练数据集和测试数据集是通过
\t
分隔的CSV文件,文件内容示例如下所示。
本示例使用的数据来自TNEWS' 今日头条中文新闻(短文本)分类,详情请参见CLUE benchmark。为了演示教程,训练集取了1000个样本,测试集取了100个样本,且样本共有如下四个字段:53360 美少女甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物 102 news_entertainment 自拍,美少女,经纪人,甜甜圈 53361 重庆美食打卡,带你领略舌尖上的重庆 102 news_food 重庆,美食,美味
- example_id:样本ID信息。
- sentence:文本信息,对应组件里的文本列选择参数。
- label:标签信息,对应组件里的标签列选择参数。
- label_str:额外信息。
- keywords:额外信息。
- 通过MaxCompute客户端,分别为训练数据集和测试数据集创建数据表,表字段定义为example_id、sentence、label、label_str及keywords。关于MaxCompute客户端的使用,请参见使用客户端(odpscmd)连接。
CREATE TABLE ez_text_classify_train( example_id INT, sentence STRING, label STRING, label_str STRING, keywords STRING); CREATE TABLE ez_text_classify_dev( example_id INT, sentence STRING, label STRING, label_str STRING, keywords STRING);
- 将下载的训练数据集train.csv和测试数据集dev.csv分别上传到已创建的MaxCompute表中。关于如何使用MaxCompute客户端上传数据,请参见Tunnel命令。
odpscmd -e --config=${odps_config} "tunnel upload train.csv ez_text_classify_test_train -fd \t;" odpscmd -e --config=${odps_config} "tunnel upload dev.csv ez_text_classify_test_dev -fd \t;"
- 创建如下工作流。
区域 描述 ① 配置读数据表-1的表名参数为已创建的ez_text_classify_test_train训练表。 ② 配置读数据表-2的表名参数为已创建的ez_text_classify_test_dev测试表。 ③ 根据训练数据和测试数据,您需要注意以下参数配置,其他参数配置请参见上文的可视化配置参数: - 配置文本列选择为sentence。
- 配置标签列选择为label。
- 配置标签枚举值为
100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116
。
④ 配置读数据表-3的表名参数为已创建的ez_text_classify_test_dev测试表。 ⑤ 使用训练好的文本分类模型对测试数据集进行预测。文本分类预测(MaxCompute)组件的配置详情,请参见文本分类预测(MaxCompute)。 - 运行工作流结束后,您可以查看输出的模型文件和预测结果。
- 在模型存储路径参数配置的OSS路径下查看输出的文本分类模型。
- 右键单击文本分类预测(MaxCompute)组件,在快捷菜单,单击 ,查看预测结果。