文本匹配训练(MaxCompute)组件的输入为两个句子,输出它们是否匹配。本文介绍文本匹配训练(MaxCompute)组件的配置方法及使用示例。

算法简介

文本匹配训练算法采用BERT类的训练模型,输入两个句子,输出它们是否匹配。BERT文本匹配本质上是一个双句分类的任务,因此可以复用文本分类的配置,仅将输入调整为两个句子即可。模型如下所示。Bert文本分类算法示意图
您可以通过以下两种方式使用文本匹配算法:
  • 在PAI-Designer中,通过可视化的方式配置组件参数,详情请参见下文的可视化配置参数

    文本匹配算法组件位于组件库自然语言处理文件夹。

  • 在DataWorks的ODPS SQL节点中,通过PAI命令的方式调用算法,详情请参见下文的PAI命令及说明。关于DataWorks的ODPS SQL节点请参见创建ODPS SQL节点

前提条件

已开通OSS并完成授权,详情请参见开通OSS服务PAI访问云产品授权:OSS

使用限制

仅PAI-Designer提供该算法组件,使用前必须开通MaxCompute资源组,并且使用GPU。

可视化配置参数

  • 输入桩
    输入桩(从左到右) 数据类型 建议上游组件 是否必选
    训练数据 MaxCompute表 读数据表
    测试数据 MaxCompute表 读数据表
  • 组件配置
    页签 参数 是否必选 描述 默认值
    字段设置 第一文本列选择 第一个文本序列在输入格式中对应的列名。
    第二文本列选择 第二个文本序列在输入格式中对应的列名。
    标签列选择 标签对应的列名。
    标签枚举值 您需要枚举出所有标签,通常为0,1
    模型存储路径 模型Checkpoint的存储路径。
    参数设置 优化器类型 选择优化器类型,支持以下取值:
    • adam
    • adagrad
    • lamb
    adam
    batchSize 特征提取的批大小,取值为INT类型。 32
    sequenceLength 序列整体最大长度,取值范围为1~512。 128
    numEpochs 训练的轮次,取值为INT类型。 3
    学习率 优化器的学习率,取值为FLOAT类型。 2e-5
    pretrainModelNameOrPath 选择预训练模型,例如:
    • pai-bert-base-zh
    • pai-bert-small-zh
    • pai-bert-tiny-zh
    pai-bert-base-zh
    模型额外参数 额外的参数。例如修改预训练模型,您可以将该参数配置为pretrain_model_name_or_path=pai-bert-base-zh pretrain_model_name_or_path=pai-bert-base-zh
    执行调优 指定Worker数 用于计算的Worker数量。 1
    指定Worker的GPU卡数 每个Worker中的GPU卡数量。 1
    指定Worker的CPU卡数 每个Worker中的CPU卡数量。 1
    分布式策略 运行模式,支持以下取值:
    • MirroredStrategy(单机多卡)
    • ExascaleStrategy(多机多卡)
    MirroredStrategy(单机多卡)
  • 输出桩
    输出桩(从左到右) 数据类型 下游组件
    输出模型 OSS路径。该路径是您在字段设置页签的模型存储路径参数配置的OSS路径,训练生成SavedModel格式的模型存储在该路径下。 文本匹配预测(MaxCompute)

PAI命令及说明

您也可以通过如下PAI命令,使用文本匹配算法。
pai -name easy_transfer_app_ext
  -Dmode=train
  -DmodelName=text_match_bert
  -DinputTable=odps://${your_project}/tables/${train},odps://${your_project}/tables/${dev}
  -DfirstSequence=query1
  -DsecondSequence=query2
  -DlabelName=is_same_question
  -DlabelEnumerateValues=0,1
  -DsequenceLength=64
  -DcheckpointDir=oss://${your_bucket}/${your_path}
  -DbatchSize=32
  -DnumEpochs=1
  -DoptimizerType=adam
  -DlearningRate=2e-5
  -DuserDefinedParameters=' pretrain_model_name_or_path=pai-bert-base-zh'
  -Dbuckets=oss://${your_bucket}/
  -Darn=${your_role_arn}
  -DossHost=${your_host}
PAI命令中的参数详情如下表所示。
参数名称 是否必选 描述 类型 默认值
mode 模式,支持以下取值:
  • train:训练
  • evaluate:评估
  • predict:预测
STRING
modelName 模型名称,与应用一一对应。支持以下模型:
  • text_classify_bert:文本分类
  • text_match_bert:文本匹配
  • sequence_labeling_bert:序列标注
STRING
inputTable MaxCompute输入表的表名。 STRING
firstSequence 第一个文本序列在输入表中对应的列名。 STRING
secondSequence 第二个文本序列在输入格式中对应的列名。 STRING
labelName 分类标签对应的列名。 STRING
labelEnumerateValues 枚举出所有标签值。 STRING
sequenceLength 序列整体最大长度。 BIGINT
checkpointDir 模型Checkpoint的存储路径。例如oss://easynlp-sh/text_match/ STRING
batchSize 特征提取的批大小。 BIGINT
numEpochs 训练的轮次。 BIGINT
optimizerType 优化器,例如adam。 STRING
learningRate 优化器的学习率,例如3e-5。 FLOAT
userDefinedParameters 额外的参数。例如修改预训练模型,您可以将该参数配置为pretrain_model_name_or_path=pai-bert-base-zh STRING
buckets 需要鉴权的OSS Bucket,与CheckpointDir对应。例如oss://easynlp-sh/ STRING
arn 您的ARN配置。 STRING
ossHost 您Bucket对应的OSS Host。 STRING
运行完成后,输出的模型存储在PAI命令的CheckpointDir参数中配置的OSS路径下,您可以登录OSS管理控制台查看模型信息。输出结果的示例如下图所示。输出结果上述文件包括:
  • 模型中间结果:avg_loss是训练loss,eval是评测结果,variables是模型参数,其他的文件为模型Checkpoint和Meta信息。
  • 可部署的模型:deployment存放可以部署的模型,可以直接对接PAI-EAS服务。

支持的计算资源

MaxCompute

示例

  1. 下载训练数据集测试数据集

    本示例使用的训练数据集和测试数据集是通过\t分隔的CSV文件。

  2. 通过MaxCompute客户端,分别为训练数据集和测试数据集创建数据表,表字段定义为is_same_questionsid1sid2query1query2。关于MaxCompute客户端的使用,请参见使用客户端(odpscmd)连接
    drop table if exists modelzoo_example_train;
    create table modelzoo_example_train(is_same_question STRING, sid1 STRING, sid2 STRING, query1 STRING,query2 STRING);
    
    drop table if exists modelzoo_example_dev;
    create table modelzoo_example_dev(is_same_question STRING, sid1 STRING, sid2 STRING, query1 STRING,query2 STRING);
  3. 将下载的训练数据集train.csv和测试数据集dev.csv分别上传到已创建的MaxCompute表中。关于如何使用MaxCompute客户端上传数据,请参见Tunnel命令
    odpscmd -e --config=${odps_config} "tunnel upload train.csv modelzoo_example_train -fd \t;"
    odpscmd -e --config=${odps_config} "tunnel upload dev.csv modelzoo_example_dev -fd \t;"
  4. 创建如下工作流。文本匹配实验
    区域 描述
    配置读数据表-1表名参数为已创建的modelzoo_example_train训练表。
    配置读数据表-2表名参数为已创建的modelzoo_example_dev测试表。
    根据训练数据和测试数据,您需要注意以下参数配置,其他参数配置请参见上文的可视化配置参数
    • 配置第一文本列选择query1
    • 配置第二文本列选择query2
    • 配置标签列选择is_same_question
    • 配置标签枚举值0,1
  5. 运行实验,结束后可以在模型存储路径参数配置的OSS路径下查看输出的文本匹配模型。