通过CatBoost工具训练GBDT模型

更新时间:2025-03-27 10:31:06

在广告点击预测、游戏用户付费或流失预测以及邮件自动分类等数据挖掘场景中,通常需要基于历史数据训练出用于分类的模型,以便预测后续行为。您可以使用云原生数据仓库 AnalyticDB MySQL 版Spark,通过CatBoost工具基于GBDT模型实现数据的分类和预测。

前提条件

  • 集群的产品系列为企业版、基础版或湖仓版

  • 集群与OSS存储空间位于相同地域。

  • 已在企业版、基础版或湖仓版集群中创建Job型资源组。具体操作,请参见新建和管理资源组

  • 已创建数据库账号。

操作步骤

步骤一:准备Maven依赖并上传至OSS

  1. 您可以通过以下任一方式获取Maven依赖:

    • 通过链接下载Maven依赖的Jar包。

    • IDEApom.xml中配置Maven依赖。代码如下:

      <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
               xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
          <modelVersion>4.0.0</modelVersion>
      
          <groupId>com.aliyun.adb.spark</groupId>
          <artifactId>CatBoostDemo</artifactId>
          <version>1.0</version>
          <packaging>jar</packaging>
      
          <properties>
              <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          </properties>
      
          <dependencies>
              <dependency>
                  <groupId>ai.catboost</groupId>
                  <artifactId>catboost-spark_3.5_2.12</artifactId>
                  <version>1.2.7</version>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-launcher_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-core_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-sql_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>com.aliyun.oss</groupId>
                  <artifactId>aliyun-sdk-oss</artifactId>
                  <version>3.16.2</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-mllib_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
          </dependencies>
      
          <build>
              <plugins>
                  <plugin>
                      <groupId>net.alchim31.maven</groupId>
                      <artifactId>scala-maven-plugin</artifactId>
                      <version>4.4.0</version>
                      <executions>
                          <execution>
                              <goals>
                                  <goal>compile</goal>
                                  <goal>testCompile</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
                  <plugin>
                      <groupId>org.apache.maven.plugins</groupId>
                      <artifactId>maven-shade-plugin</artifactId>
                      <version>3.1.1</version>
                      <configuration>
                          <createDependencyReducedPom>false</createDependencyReducedPom>
                      </configuration>
                      <executions>
                          <execution>
                              <phase>package</phase>
                              <goals>
                                  <goal>shade</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
              </plugins>
          </build>
      </project>
  2. 如果是在IDEApom.xml中配置的Maven依赖,需执行mvn clean package -DskipTests命令打包。通过链接下载的Maven依赖不需要打包,可跳过该步骤。

  3. Maven依赖上传OSS中。

步骤二:编写应用程序并上传至OSS

  1. 编写应用程序。

    Scala
    Python
    package com.aliyun.adb
    
    import ai.catboost.spark.{CatBoostClassificationModel, CatBoostClassifier, Pool}
    import org.apache.spark.ml.linalg.{SQLDataTypes, Vectors}
    import org.apache.spark.sql.types.{StringType, StructField, StructType}
    import org.apache.spark.sql.{Row, SparkSession}
    
    object CatBoostDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .appName("CatBoost Example")
          .getOrCreate()
    
        val srcDataSchema = Seq(
          StructField("features", SQLDataTypes.VectorType),
          StructField("label", StringType)
        )
    
        val trainData = Seq(
          Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
          Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
          Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
          Row(Vectors.dense(0.8, 0.62, 0.0), "0")
        )
    
        val trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
        val trainPool = new Pool(trainDf)
    
        val evalData = Seq(
          Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
          Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
          Row(Vectors.dense(0.77, 0.0, 0.0), "1")
        )
    
        val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
        val evalPool = new Pool(evalDf)
    
        val classifier = new CatBoostClassifier
    
        // train a model
        val model = classifier.fit(trainPool, Array[Pool](evalPool))
    
        // apply the model
        val predictions = model.transform(evalPool.data)
        println("predictions")
        predictions.show()
    
        // save the model as a local file in CatBoost native format
        val savedNativeModelPath = "./multiclass_model.cbm"
        model.saveNativeModel(savedNativeModelPath)
    
        // load the model as a local file in CatBoost native format
        val loadedNativeModel = CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
        val predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
        println("predictionsFromLoadedNativeModel")
        predictionsFromLoadedNativeModel.show()
        System.exit(0)
      }
    }
    
    重要

    CatBoost依赖于SparkSession初始化后的上下文环境,因此必须在创建SparkSession对象后,再执行import catboost_spark,提前执行可能导致依赖加载失败。

    from pyspark.sql import Row,SparkSession
    from pyspark.ml.linalg import Vectors, VectorUDT
    from pyspark.sql.types import *
    
    
    spark = SparkSession.builder.getOrCreate()
    
    import catboost_spark
    
    srcDataSchema = [
        StructField("features", VectorUDT()),
        StructField("label", StringType())
    ]
    
    trainData = [
        Row(Vectors.dense(0.1, 0.2, 0.11), "1"),
        Row(Vectors.dense(0.97, 0.82, 0.33), "2"),
        Row(Vectors.dense(0.13, 0.22, 0.23), "1"),
        Row(Vectors.dense(0.8, 0.62, 0.0), "0")
    ]
    
    trainDf = spark.createDataFrame(spark.sparkContext.parallelize(trainData), StructType(srcDataSchema))
    trainPool = catboost_spark.Pool(trainDf)
    
    evalData = [
        Row(Vectors.dense(0.22, 0.33, 0.9), "2"),
        Row(Vectors.dense(0.11, 0.1, 0.21), "0"),
        Row(Vectors.dense(0.77, 0.0, 0.0), "1")
    ]
    
    evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))
    evalPool = catboost_spark.Pool(evalDf)
    
    classifier = catboost_spark.CatBoostClassifier()
    
    # train a model
    model = classifier.fit(trainPool, evalDatasets=[evalPool])
    
    # apply the model
    predictions = model.transform(evalPool.data)
    predictions.show()
    
    # save the model as a local file in CatBoost native format
    savedNativeModelPath = './multiclass_model.cbm'
    model.saveNativeModel(savedNativeModelPath)
    # load the model as a local file in CatBoost native format
    
    loadedNativeModel = catboost_spark.CatBoostClassificationModel.loadNativeModel(savedNativeModelPath)
    
    predictionsFromLoadedNativeModel = loadedNativeModel.transform(evalPool.data)
    predictionsFromLoadedNativeModel.show()
    
  2. 如果是Scala程序,则需要将其打成Jar包。Python文件不需要打包,可跳过此步骤。

  3. 将打好的Jar包或.py的文件上传OSS。

步骤三:提交Spark作业

  1. 登录云原生数据仓库AnalyticDB MySQL控制台,在左上角选择集群所在地域。在左侧导航栏,单击集群列表。在集群列表上方,选择产品系列,然后单击目标集群ID。

  2. 在左侧导航栏,单击作业开发 > Spark Jar开发

  3. 选择Job型资源组和Spark作业类型。本文示例为Batch

  4. 根据步骤二中编写的应用程序,在编辑器中输入以下代码后,单击立即执行

    Scala应用
    Python应用
     {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "ClassName":"com.aliyun.adb.CatBoostDemo",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.executor.memoryOverhead": "4096",
            "spark.task.cpus": 2,
            "spark.adb.version": "3.5"
        }
     {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.executor.memoryOverhead": "4096",
            "spark.task.cpus": 2,
            "spark.adb.version": "3.5"
        }
    说明

    本文Python应用程序需调用Jar包里的方法,因此需要在Spark Jar作业代码中添加jars参数,并指向步骤一Maven依赖所在的OSS路径。

    参数说明:

    参数

    是否必填

    说明

    参数

    是否必填

    说明

    name

    Spark作业名称。

    file

    • Scala:步骤二中编写的Scala应用所在的OSS路径。

    • Python:步骤二中编写的Python应用所在的OSS路径。

    jars

    步骤一中准备的Maven依赖所在的OSS路径。

    ClassName

    条件必填

    Scala应用入口类名称。提交Scala应用时必填。

    pyFiles

    条件必填

    步骤一中准备的Maven依赖所在的OSS路径。提交Python应用时必填。

    spark.adb.version

    Spark版本,此处必须显式指定为3.5。

    spark.task.cpus

    Spark Executor资源规格所对应的CPU核数, 以确保每个Executor中只有一个CatBoost Worker的进程。

    例如Spark Executor资源的规格为medium(即spark.executor.resourceSpecmedium),此处则需要配置为2。

    其他conf参数

    与开源Spark中的配置项基本一致,参数格式为key: value形式,多个参数之间以英文逗号(,)分隔。与开源Spark用法不一致的配置参数及AnalyticDB for MySQL特有的配置参数,请参见Spark应用配置参数说明

  5. (可选)在应用列表中,单击目标作业操作列的日志,在日志中可以查看返回结果。

  • 本页导读 (1)
  • 前提条件
  • 操作步骤
  • 步骤一:准备Maven依赖并上传至OSS
  • 步骤二:编写应用程序并上传至OSS
  • 步骤三:提交Spark作业