通过CatBoost工具训练GBDT模型
在广告点击预测、游戏用户付费或流失预测以及邮件自动分类等数据挖掘场景中,通常需要基于历史数据训练出用于分类的模型,以便预测后续行为。您可以使用云原生数据仓库 AnalyticDB MySQL 版Spark,通过CatBoost工具基于GBDT模型实现数据的分类和预测。
前提条件
集群的产品系列为企业版、基础版或湖仓版。
集群与OSS存储空间位于相同地域。
已在企业版、基础版或湖仓版集群中创建Job型资源组。具体操作,请参见新建和管理资源组。
已创建数据库账号。
如果是通过阿里云账号访问,只需创建高权限账号。具体操作,请参见创建高权限账号。
如果是通过RAM用户访问,需要创建高权限账号和普通账号并且将RAM用户绑定到普通账号上。具体操作,请参见创建数据库账号和绑定或解绑RAM用户与数据库账号。
操作步骤
步骤一:准备Maven依赖并上传至OSS
您可以通过以下任一方式获取Maven依赖:
通过链接下载Maven依赖的Jar包。
在IDEA的pom.xml中配置Maven依赖。代码如下:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.aliyun.adb.spark</groupId> <artifactId>CatBoostDemo</artifactId> <version>1.0</version> <packaging>jar</packaging> <properties> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> </properties> <dependencies> <dependency> <groupId>ai.catboost</groupId> <artifactId>catboost-spark_3.5_2.12</artifactId> <version>1.2.7</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-launcher_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.aliyun.oss</groupId> <artifactId>aliyun-sdk-oss</artifactId> <version>3.16.2</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> </dependencies> <build> <plugins> <plugin> <groupId>net.alchim31.maven</groupId> <artifactId>scala-maven-plugin</artifactId> <version>4.4.0</version> <executions> <execution> <goals> <goal>compile</goal> <goal>testCompile</goal> </goals> </execution> </executions> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> <version>3.1.1</version> <configuration> <createDependencyReducedPom>false</createDependencyReducedPom> </configuration> <executions> <execution> <phase>package</phase> <goals> <goal>shade</goal> </goals> </execution> </executions> </plugin> </plugins> </build> </project>
如果是在IDEA的pom.xml中配置的Maven依赖,需执行
mvn clean package -DskipTests
命令打包。通过链接下载的Maven依赖不需要打包,可跳过该步骤。将Maven依赖上传至OSS中。
步骤二:编写应用程序并上传至OSS
编写应用程序。
ScalaPythonpackage com.aliyun.adb import ai.catboost.spark.{CatBoostClassificationModel, CatBoostClassifier, Pool} import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors} import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{Row, SparkSession} object CatBoostDemo { def main(args: Array[String]): Unit = { val spark = SparkSession .builder() .appName("CatBoost Example") .getOrCreate() val srcDataSchema = Seq( StructField("features", SQLDataTypes.VectorType), StructField("label", StringType) ) val trainData = Seq( Row(Vectors.dense(0.1, 0.2, 0.11), "1"), Row(Vectors.dense(0.97, 0.82, 0.33), "2"), Row(Vectors.dense(0.13, 0.22, 0.23), "1"), Row(Vectors.dense(0.8, 0.62, 0.0), "0") ) val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema)) val trainPool = new Pool(trainDf) val evalData = Seq( Row(Vectors.dense(0.22, 0.33, 0.9), "2"), Row(Vectors.dense(0.11, 0.1, 0.21), "0"), Row(Vectors.dense(0.77, 0.0, 0.0), "1") ) val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema)) val evalPool = new Pool(evalDf) val classifier = new CatBoostClassifier // train a model val model = classifier.fit(trainPool, Array[Pool](evalPool)) // apply the model val predictions = model.transform(evalPool.data) println("predictions") predictions.show() // save the model as a local file in CatBoost native format val savedNativeModelPath = "./multiclass_model.cbm" model.saveNativeModel(savedNativeModelPath) // load the model as a local file in CatBoost native format val loadedNativeModel = CatBoostClassificationModel.loadNativeModel(savedNativeModelPath) val predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data) println("predictionsFromLoadedNativeModel") predictionsFromLoadedNativeModel.show() System.exit(0) } }
CatBoost依赖于SparkSession初始化后的上下文环境,因此必须在创建SparkSession对象后,再执行
import catboost_spark
,提前执行可能导致依赖加载失败。from pyspark.sql import Row,SparkSession from pyspark.ml.linalg import Vectors, VectorUDT from pyspark.sql.types import * spark = SparkSession.builder.getOrCreate() import catboost_spark srcDataSchema = [ StructField("features", VectorUDT()), StructField("label", StringType()) ] trainData = [ Row(Vectors.dense(0.1, 0.2, 0.11), "1"), Row(Vectors.dense(0.97, 0.82, 0.33), "2"), Row(Vectors.dense(0.13, 0.22, 0.23), "1"), Row(Vectors.dense(0.8, 0.62, 0.0), "0") ] trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema)) trainPool = catboost_spark.Pool(trainDf) evalData = [ Row(Vectors.dense(0.22, 0.33, 0.9), "2"), Row(Vectors.dense(0.11, 0.1, 0.21), "0"), Row(Vectors.dense(0.77, 0.0, 0.0), "1") ] evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema)) evalPool = catboost_spark.Pool(evalDf) classifier = catboost_spark.CatBoostClassifier() # train a model model = classifier.fit(trainPool, evalDatasets=[evalPool]) # apply the model predictions = model.transform(evalPool.data) predictions.show() # save the model as a local file in CatBoost native format savedNativeModelPath = './multiclass_model.cbm' model.saveNativeModel(savedNativeModelPath) # load the model as a local file in CatBoost native format loadedNativeModel = catboost_spark.CatBoostClassificationModel.loadNativeModel(savedNativeModelPath) predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data) predictionsFromLoadedNativeModel.show()
如果是Scala程序,则需要将其打成Jar包。Python文件不需要打包,可跳过此步骤。
将打好的Jar包或.py的文件上传至OSS。
步骤三:提交Spark作业
登录云原生数据仓库AnalyticDB MySQL控制台,在左上角选择集群所在地域。在左侧导航栏,单击集群列表。在集群列表上方,选择产品系列,然后单击目标集群ID。
在左侧导航栏,单击
。选择Job型资源组和Spark作业类型。本文示例为Batch。
根据步骤二中编写的应用程序,在编辑器中输入以下代码后,单击立即执行。
Scala应用Python应用{ "name": "CatBoostDemo", "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar", "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar", "ClassName":"com.aliyun.adb.CatBoostDemo", "conf": { "spark.driver.resourceSpec": "large", "spark.executor.instances": 2, "spark.executor.resourceSpec": "medium", "spark.executor.memoryOverhead": "4096", "spark.task.cpus": 2, "spark.adb.version": "3.5" }
{ "name": "CatBoostDemo", "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py", "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar", "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar", "conf": { "spark.driver.resourceSpec": "large", "spark.executor.instances": 2, "spark.executor.resourceSpec": "medium", "spark.executor.memoryOverhead": "4096", "spark.task.cpus": 2, "spark.adb.version": "3.5" }
本文Python应用程序需调用Jar包里的方法,因此需要在Spark Jar作业代码中添加
jars
参数,并指向步骤一Maven依赖所在的OSS路径。参数说明:
参数
是否必填
说明
参数
是否必填
说明
name
否
Spark作业名称。
file
是
Scala:步骤二中编写的Scala应用所在的OSS路径。
Python:步骤二中编写的Python应用所在的OSS路径。
jars
是
步骤一中准备的Maven依赖所在的OSS路径。
ClassName
条件必填
Scala应用入口类名称。提交Scala应用时必填。
pyFiles
条件必填
步骤一中准备的Maven依赖所在的OSS路径。提交Python应用时必填。
spark.adb.version
是
Spark版本,此处必须显式指定为3.5。
spark.task.cpus
是
Spark Executor资源规格所对应的CPU核数, 以确保每个Executor中只有一个CatBoost Worker的进程。
例如Spark Executor资源的规格为medium(即spark.executor.resourceSpec为medium),此处则需要配置为2。
其他conf参数
否
与开源Spark中的配置项基本一致,参数格式为
key: value
形式,多个参数之间以英文逗号(,)分隔。与开源Spark用法不一致的配置参数及AnalyticDB for MySQL特有的配置参数,请参见Spark应用配置参数说明。(可选)在应用列表中,单击目标作业操作列的日志,在日志中可以查看返回结果。
- 本页导读 (1)
- 前提条件
- 操作步骤
- 步骤一:准备Maven依赖并上传至OSS
- 步骤二:编写应用程序并上传至OSS
- 步骤三:提交Spark作业