文本匹配训练(MaxCompute)
文本匹配训练(MaxCompute)组件的输入为两个句子,输出它们是否匹配。本文介绍文本匹配训练(MaxCompute)组件的配置方法及使用示例。
算法简介
文本匹配训练算法采用BERT类的训练模型,输入两个句子,输出它们是否匹配。BERT文本匹配本质上是一个双句分类的任务,因此可以复用文本分类的配置,仅将输入调整为两个句子即可。模型如下所示。
您可以通过以下两种方式使用文本匹配算法:
在Designer中,通过可视化的方式配置组件参数,详情请参见下文的可视化配置参数。
文本匹配算法组件位于组件库自然语言处理文件夹。
在DataWorks的ODPS SQL节点中,通过PAI命令的方式调用算法,详情请参见下文的PAI命令及说明。关于DataWorks的ODPS SQL节点请参见开发ODPS SQL任务。
前提条件
已开通OSS并完成授权,详情请参见开通OSS服务和云产品依赖与授权:Designer。
使用限制
仅Designer提供该算法组件,使用前必须开通MaxCompute资源组,并且使用GPU。
可视化配置参数
输入桩
输入桩(从左到右)
数据类型
建议上游组件
是否必选
训练数据
MaxCompute表
是
测试数据
MaxCompute表
是
组件配置
页签
参数
是否必选
描述
默认值
字段设置
第一文本列选择
是
第一个文本序列在输入格式中对应的列名。
无
第二文本列选择
是
第二个文本序列在输入格式中对应的列名。
无
标签列选择
是
标签对应的列名。
无
标签枚举值
是
您需要枚举出所有标签,通常为
0,1
。无
模型存储路径
是
模型Checkpoint的存储路径。
无
参数设置
优化器类型
否
选择优化器类型,支持以下取值:
adam
adagrad
lamb
adam
batchSize
否
特征提取的批大小,取值为INT类型。
32
sequenceLength
否
序列整体最大长度,取值范围为1~512。
128
numEpochs
否
训练的轮次,取值为INT类型。
3
学习率
否
优化器的学习率,取值为FLOAT类型。
2e-5
pretrainModelNameOrPath
否
选择预训练模型,例如:
pai-bert-base-zh
pai-bert-small-zh
pai-bert-tiny-zh
pai-bert-base-zh
模型额外参数
否
额外的参数。例如修改预训练模型,您可以将该参数配置为
pretrain_model_name_or_path=pai-bert-base-zh
。pretrain_model_name_or_path=pai-bert-base-zh
执行调优
指定Worker数
否
用于计算的Worker数量。
1
指定Worker的GPU卡数
否
每个Worker中的GPU卡数量。
1
指定Worker的CPU卡数
否
每个Worker中的CPU卡数量。
1
分布式策略
是
运行模式,支持以下取值:
MirroredStrategy(单机多卡)
ExascaleStrategy(多机多卡)
MirroredStrategy(单机多卡)
输出桩
输出桩(从左到右)
数据类型
下游组件
输出模型
OSS路径。该路径是您在字段设置页签的模型存储路径参数配置的OSS路径,训练生成SavedModel格式的模型存储在该路径下。
文本匹配预测(MaxCompute)
PAI命令及说明
您也可以通过如下PAI命令,使用文本匹配算法。
pai -name easy_transfer_app_ext
-Dmode=train
-DmodelName=text_match_bert
-DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}
-DfirstSequence=query1
-DsecondSequence=query2
-DlabelName=is_same_question
-DlabelEnumerateValues=0,1
-DsequenceLength=64
-DcheckpointDir=oss://${your_bucket}/${your_path}
-DbatchSize=32
-DnumEpochs=1
-DoptimizerType=adam
-DlearningRate=2e-5
-DuserDefinedParameters=' pretrain_model_name_or_path=pai-bert-base-zh'
-Dbuckets=oss://${your_bucket}/
-Darn=${your_role_arn}
-DossHost=${your_host}
PAI命令中的参数详情如下表所示。
参数名称 | 是否必选 | 描述 | 类型 | 默认值 |
mode | 是 | 模式,支持以下取值:
| STRING | 无 |
modelName | 是 | 模型名称,与应用一一对应。支持以下模型:
| STRING | 无 |
inputTable | 是 | MaxCompute输入表的表名。 | STRING | 无 |
firstSequence | 是 | 第一个文本序列在输入表中对应的列名。 | STRING | 无 |
secondSequence | 是 | 第二个文本序列在输入格式中对应的列名。 | STRING | 无 |
labelName | 是 | 分类标签对应的列名。 | STRING | 无 |
labelEnumerateValues | 是 | 枚举出所有标签值。 | STRING | 无 |
sequenceLength | 是 | 序列整体最大长度。 | BIGINT | 无 |
checkpointDir | 是 | 模型Checkpoint的存储路径。例如 | STRING | 无 |
batchSize | 是 | 特征提取的批大小。 | BIGINT | 无 |
numEpochs | 是 | 训练的轮次。 | BIGINT | 无 |
optimizerType | 是 | 优化器,例如adam。 | STRING | 无 |
learningRate | 是 | 优化器的学习率,例如3e-5。 | FLOAT | 无 |
userDefinedParameters | 是 | 额外的参数。例如修改预训练模型,您可以将该参数配置为 | STRING | 无 |
buckets | 是 | 需要鉴权的OSS Bucket,与CheckpointDir对应。例如 | STRING | 无 |
arn | 是 | 您的ARN配置。 | STRING | 无 |
ossHost | 是 | 您Bucket对应的OSS Host。 | STRING | 无 |
运行完成后,输出的模型存储在PAI命令的CheckpointDir参数中配置的OSS路径下,您可以登录OSS管理控制台查看模型信息。输出结果的示例如下图所示。上述文件包括:
模型中间结果:avg_loss是训练loss,eval是评测结果,variables是模型参数,其他的文件为模型Checkpoint和Meta信息。
可部署的模型:deployment存放可以部署的模型,可以直接对接EAS服务。
支持的计算资源
MaxCompute
示例
本示例使用的训练数据集和测试数据集是通过
\t
分隔的CSV文件。通过MaxCompute客户端,分别为训练数据集和测试数据集创建数据表,表字段定义为is_same_question、sid1、sid2、query1及query2。关于MaxCompute客户端的使用,请参见使用客户端(odpscmd)连接。
drop table if exists modelzoo_example_train; create table modelzoo_example_train(is_same_question STRING, sid1 STRING, sid2 STRING, query1 STRING,query2 STRING); drop table if exists modelzoo_example_dev; create table modelzoo_example_dev(is_same_question STRING, sid1 STRING, sid2 STRING, query1 STRING,query2 STRING);
将下载的训练数据集train.csv和测试数据集dev.csv分别上传到已创建的MaxCompute表中。关于如何使用MaxCompute客户端上传数据,请参见Tunnel命令。
odpscmd -e --config=${odps_config} "tunnel upload train.csv modelzoo_example_train -fd \t;" odpscmd -e --config=${odps_config} "tunnel upload dev.csv modelzoo_example_dev -fd \t;"
创建如下工作流。
区域
描述
①
配置读数据表-1的表名参数为已创建的modelzoo_example_train训练表。
②
配置读数据表-2的表名参数为已创建的modelzoo_example_dev测试表。
③
根据训练数据和测试数据,您需要注意以下参数配置,其他参数配置请参见上文的可视化配置参数:
配置第一文本列选择为query1。
配置第二文本列选择为query2。
配置标签列选择为is_same_question。
配置标签枚举值为
0,1
。
运行实验,结束后可以在模型存储路径参数配置的OSS路径下查看输出的文本匹配模型。