文本分类训练(MaxCompute)

更新时间: 2023-11-02 15:42:48

文本分类(MaxCompute)算法组件是以原始文本作为输入,输出类别标签。该算法组件集成了多种基于BERT的文本分类模型。本文介绍文本分类训练(MaxCompute)算法组件的配置方法及使用示例。

注意事项

该组件目前仅支持读取BOOLEAN、BIGINT、DOUBLE、STRING和DATETIME类型的数据。

算法简介

文本分类训练(MaxCompute)是通用的基于BERT的分类模型,输入文本数据,输出分类标签,适用于文本打标和文本情感分析等任务。模型结构如下图所示。Bert文本分类算法示意图

您可以通过以下两种方式使用文本分类(MaxCompute)算法:

  • Designer中,通过可视化的方式配置组件参数,详情请参见下文的可视化配置参数

    文本分类(MaxCompute)算法组件位于组件库自然语言处理文件夹。

  • 在DataWorks的ODPS SQL节点中,通过PAI命令的方式调用算法,详情请参见下文的PAI命令及说明。关于DataWorks的ODPS SQL节点请参见开发ODPS SQL任务

前提条件

已开通OSS并完成授权,详情请参见开通OSS服务云产品依赖与授权:Designer

使用限制

Designer提供该算法组件,使用前必须开通MaxCompute资源组,并且使用GPU。

可视化配置参数

  • 输入桩

    输入桩(从左到右)

    数据类型

    建议上游组件

    是否必选

    训练数据

    MaxCompute表

    读数据表

    测试数据

    MaxCompute表

    读数据表

  • 组件参数

    页签

    参数

    是否必选

    描述

    默认值

    字段设置

    文本列选择

    文本序列在输入表中对应的列名。

    标签列选择

    分类标签对应的列名。

    标签枚举值

    您需要枚举出所有标签,多个标签之间使用半角逗号(,)分隔。

    样本权重列

    样本加权列。每个样本的loss计算时,您可以加权重。

    模型存储路径

    模型Checkpoint的存储路径。例如oss://pai-asr-test-bj/test_text_match3/

    参数设置

    优化器类型

    选择优化器类型,支持以下取值:

    • adam

    • adagrad

    • lamb

    adam

    batchSize

    特征提取的批大小,取值为INT类型。

    32

    sequenceLength

    序列整体最大长度,取值范围为1~512。

    128

    numEpochs

    训练的轮次,取值为INT类型。

    3

    学习率

    优化器的学习率,取值为FLOAT类型。

    2e-5

    pretrainModelNameOrPath

    选择预训练模型,例如:

    • pai-bert-base-zh

    • pai-bert-small-zh

    • pai-bert-tiny-zh

    pai-bert-base-zh

    模型额外参数

    额外的参数。例如修改预训练模型,您可以将该参数配置为pretrain_model_name_or_path=pai-bert-base-zh

    执行调优

    指定Worker数

    用于计算的Worker数量。

    1

    指定Worker的GPU卡数

    每个Worker中的GPU卡数量。

    1

    指定Worker的CPU卡数

    每个Worker中的CPU卡数量。

    1

    分布式策略

    运行模式,支持以下取值:

    • MirroredStrategy(单机多卡)

    • ExascaleStrategy(多机多卡)

    MirroredStrategy(单机多卡)

  • 输出桩

    输出桩(从左到右)

    数据类型

    下游组件

    输出模型

    OSS路径。该路径是您在字段设置页签的模型存储路径参数配置的OSS路径,训练生成SavedModel格式的模型存储在该路径下。

    文本分类预测(MaxCompute)

PAI命令及说明

您也可以通过如下PAI命令,使用文本分类(MaxCompute)算法。

pai -name easy_transfer_app_ext
  -Dmode=train
  -DmodelName=text_classify_bert
  -DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}
  -DfirstSequence=content
  -DlabelName=label
  -DlabelEnumerateValues=100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116
  -DsequenceLength=64
  -DcheckpointDir=oss://${your_bucket}/${your_path}
  -DbatchSize=32
  -DnumEpochs=1
  -DoptimizerType=adam
  -DlearningRate=2e-5
  -DuserDefinedParameters='pretrain_model_name_or_path=pai-bert-base-zh'
  -Dbuckets=oss://${your_bucket}/
  -Darn=${your_role_arn}
  -DossHost=${your_host}

PAI命令中的参数详情如下表所示。

参数名称

是否必选

描述

类型

默认值

mode

模式,支持以下取值:

  • train:训练

  • evaluate:评估

  • predict:预测

STRING

modelName

模型名称,与应用一一对应。支持以下模型:

  • text_classify_bert:文本分类

  • text_match_bert:文本匹配

  • sequence_labeling_bert:序列标注

STRING

inputTable

MaxCompute输入表的表名。

STRING

firstSequence

文本序列在输入表中对应的列名。

STRING

labelName

分类标签对应的列名。

STRING

labelEnumerateValues

枚举出所有标签值。

STRING

sequenceLength

序列整体最大长度。

BIGINT

checkpointDir

模型Checkpoint的存储路径。例如oss://easynlp-sh/text_match/

STRING

batchSize

特征提取的批大小。

BIGINT

numEpochs

训练的轮次。

BIGINT

optimizerType

优化器,例如adam。

STRING

learningRate

优化器的学习率,例如3e-5。

DOUBLE

userDefinedParameters

额外的参数。例如修改预训练模型,您可以将该参数配置为pretrain_model_name_or_path=pai-bert-base-zh

STRING

buckets

需要鉴权的OSS Bucket,与CheckpointDir对应。例如oss://easynlp-sh/

STRING

arn

您的ARN配置。

STRING

ossHost

您Bucket对应的OSS Host。

STRING

运行完成后,输出的模型存储在PAI命令的CheckpointDir参数中配置的OSS路径下,您可以登录OSS管理控制台查看模型信息。输出结果的示例如下图所示。输出结果上述文件包括:

  • 模型中间结果:avg_loss是训练loss,eval是评测结果,variables是模型参数,其他的文件为模型Checkpoint和Meta信息。

  • 可部署的模型:deployment存放可以部署的模型,可以直接对接EAS服务。

支持的计算资源

MaxCompute

示例

  1. 下载训练数据集测试数据集

    本示例使用的训练数据集和测试数据集是通过\t分隔的CSV文件,文件内容示例如下所示。

    53360    美少女甜甜圈自拍,迷之角度竟这么好看,美吸引一切事物    102    news_entertainment    自拍,美少女,经纪人,甜甜圈
    53361    重庆美食打卡,带你领略舌尖上的重庆    102    news_food    重庆,美食,美味

    本示例使用的数据来自TNEWS' 今日头条中文新闻(短文本)分类,详情请参见CLUE benchmark。为了演示教程,训练集取了1000个样本,测试集取了100个样本,且样本共有如下四个字段:

    • example_id:样本ID信息。

    • sentence:文本信息,对应组件里的文本列选择参数。

    • label:标签信息,对应组件里的标签列选择参数。

    • label_str:额外信息。

    • keywords:额外信息。

  2. 通过MaxCompute客户端,分别为训练数据集和测试数据集创建数据表,表字段定义为example_idsentencelabellabel_strkeywords。关于MaxCompute客户端的使用,请参见使用客户端(odpscmd)连接

    CREATE TABLE ez_text_classify_train(
        example_id INT, sentence STRING, label STRING, label_str STRING, keywords STRING);
    CREATE TABLE ez_text_classify_dev(
        example_id INT, sentence STRING, label STRING, label_str STRING, keywords STRING);
  3. 将下载的训练数据集train.csv和测试数据集dev.csv分别上传到已创建的MaxCompute表中。关于如何使用MaxCompute客户端上传数据,请参见Tunnel命令

    odpscmd -e --config=${odps_config} "tunnel upload train.csv ez_text_classify_test_train -fd \t;"
    odpscmd -e --config=${odps_config} "tunnel upload dev.csv ez_text_classify_test_dev -fd \t;"
  4. 创建如下工作流。示例工作流

    区域

    描述

    配置读数据表-1表名参数为已创建的ez_text_classify_test_train训练表。

    配置读数据表-2表名参数为已创建的ez_text_classify_test_dev测试表。

    根据训练数据和测试数据,您需要注意以下参数配置,其他参数配置请参见上文的可视化配置参数

    • 配置文本列选择sentence

    • 配置标签列选择label

    • 配置标签枚举值100,101,102,103,104,105,106,107,108,109,110,112,113,114,115,116

    配置读数据表-3表名参数为已创建的ez_text_classify_test_dev测试表。

    使用训练好的文本分类模型对测试数据集进行预测。文本分类预测(MaxCompute)组件的配置详情,请参见文本分类预测(MaxCompute)

  5. 运行工作流结束后,您可以查看输出的模型文件和预测结果。

    • 模型存储路径参数配置的OSS路径下查看输出的文本分类模型。

    • 右键单击文本分类预测(MaxCompute)组件,在快捷菜单,单击查看数据 > 输出,查看预测结果。

阿里云首页 人工智能平台 PAI 相关技术圈