通过LightGBM工具训练GBDT模型

更新时间:2025-03-14 10:26:59

在广告点击预测、游戏用户付费或流失预测以及邮件自动分类等数据挖掘场景中,通常需要基于历史数据训练出用于分类的模型,以便预测后续行为。您可以使用云原生数据仓库 AnalyticDB MySQL 版Spark,通过LightGBM工具基于GBDT模型实现数据的分类和预测。与单机部署的XGBoost、CatBoost相比,部署在Spark上的LightGBM能够充分利用分布式计算能力,处理TB级大规模数据。本文介绍如何通过LightGBM工具训练GBDT模型,实现数据的分类和预测。

前提条件

  • 集群的产品系列为企业版、基础版或湖仓版

  • 集群与OSS存储空间位于相同地域。

  • 已在企业版、基础版或湖仓版集群中创建Job型资源组。具体操作,请参见新建和管理资源组

  • 已创建数据库账号。

操作步骤

步骤一:准备Maven依赖并上传至OSS

  1. 您可以通过以下任一方式获取Maven依赖:

    • 通过链接下载Maven依赖的Jar包。

    • IDEApom.xml中配置Maven依赖。代码如下:

      <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
               xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
          <modelVersion>4.0.0</modelVersion>
      
          <groupId>com.aliyun.adb.spark</groupId>
          <artifactId>LightgbmDemo</artifactId>
          <version>1.0</version>
          <packaging>jar</packaging>
      
          <properties>
              <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
          </properties>
      
          <dependencies>
              <!-- https://mvnrepository.com/artifact/com.microsoft.azure/synapseml -->
              <dependency>
                  <groupId>com.microsoft.azure</groupId>
                  <artifactId>synapseml_2.12</artifactId>
                  <version>1.0.8</version>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-launcher_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-core_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-sql_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
              <dependency>
                  <groupId>org.apache.spark</groupId>
                  <artifactId>spark-mllib_2.12</artifactId>
                  <version>3.5.1</version>
                  <scope>provided</scope>
              </dependency>
          </dependencies>
      
          <build>
              <plugins>
                  <plugin>
                      <groupId>net.alchim31.maven</groupId>
                      <artifactId>scala-maven-plugin</artifactId>
                      <version>4.4.0</version>
                      <executions>
                          <execution>
                              <goals>
                                  <goal>compile</goal>
                                  <goal>testCompile</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
                  <plugin>
                      <groupId>org.apache.maven.plugins</groupId>
                      <artifactId>maven-shade-plugin</artifactId>
                      <version>3.1.1</version>
                      <configuration>
                          <createDependencyReducedPom>false</createDependencyReducedPom>
                      </configuration>
                      <executions>
                          <execution>
                              <phase>package</phase>
                              <goals>
                                  <goal>shade</goal>
                              </goals>
                          </execution>
                      </executions>
                  </plugin>
              </plugins>
          </build>
      </project>
      
  2. 如果是在IDEApom.xml中配置的Maven依赖,需执行mvn clean package -DskipTests命令打包。通过链接下载的Maven依赖不需要打包,可跳过该步骤。

  3. Maven依赖上传OSS中。

步骤二:编写应用程序并上传至OSS

  1. 编写应用程序。

    Scala程序
    Python程序
    package com.aliyun.adb
    
    import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    object GBDTdemo {
      def main(args: Array[String]): Unit = {
        // 注册Spark Session
        val spark = SparkSession
          .builder()
          .appName("lightgbm Example")
          .getOrCreate()
    
        // 读取训练数据
        val dataPath = s"${sys.env.getOrElse("SPARK_HOME", "/opt/spark")}/data/mllib/sample_multiclass_classification_data.txt"
        val data: Dataset[Row] = spark.read.format("libsvm").load(dataPath)
        data.show()
    
        val classifier = new LightGBMClassifier()
        classifier.setLabelCol("label")
        classifier.setFeaturesCol("features")
    
        // 设置Category型的特征列, 这里我们假设第0列和第1列是Category型的特征列
        // 示例数据中不包含Category型的特征列,这里需注释这行代码,仅为演示
        // classifier.setCategoricalSlotIndexes(Array(0, 1))
        // classifier.setCategoricalSlotNames(Array.apply("enumCol1", "enumCol2"))
    
        // 将数据分为训练集和验证集
        val Array(trainData, validData) = data.randomSplit(Array(0.6, 0.4))
        val model = classifier.fit(trainData)
        model.saveNativeModel("oss://bucket_name/model", overwrite = true)
    
        // Run predictions on the validation data
        val predictions = model.transform(validData)
        // Show the prediction results
        predictions.show()
        // print accuracy
        val accuracy = predictions.filter("label == prediction").count().toDouble / predictions.count()
        println(s"Accuracy: $accuracy")
        System.exit(0)
      }
    }
    
    import os
    
    from pyspark.sql import SparkSession
    
    if __name__ == '__main__':
        # init spark
        spark = SparkSession.builder.appName("lightgbm_spark_train").getOrCreate()
        # read data
        f = os.environ.get("SPARK_HOME") + "/data/mllib/sample_multiclass_classification_data.txt"
        df = spark.read.format("libsvm").load(f)
        # split data
        train, test = df.randomSplit([0.8, 0.2])
        # train model
        from synapse.ml.lightgbm import LightGBMClassifier
        model = LightGBMClassifier(learningRate=0.3,
                                   numIterations=20,
                                   numLeaves=4).fit(train)
        # predict
        prediction = model.transform(test)
        prediction.show()
        # stop spark
        spark.stop()
        
  2. 如果是Scala程序,则需要将其打成Jar包。Python文件不需要打包,可跳过此步骤。

  3. 将打好的Jar包或.py的文件上传OSS。

步骤三:提交Spark作业

  1. 登录云原生数据仓库AnalyticDB MySQL控制台,在左上角选择集群所在地域。在左侧导航栏,单击集群列表。在集群列表上方,选择产品系列,然后单击目标集群ID。

  2. 在左侧导航栏,单击作业开发 > Spark Jar 开发

  3. 选择Job型资源组和Spark作业类型。本文示例为Batch

  4. 根据步骤二中编写的应用程序,在编辑器中输入以下代码后,单击立即执行

    Scala程序
    Python程序
    {
        "name": "LightdbmDemo",
        "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar",
        "jars": "oss://testBucketName/LightgbmDemo-1.0.jar",
        "ClassName": "com.aliyun.adb.GBDTdemo",
        "conf": {
            "spark.driver.resourceSpec": "large",
            "spark.executor.instances": 2,
            "spark.executor.resourceSpec": "medium",
            "spark.adb.version": "3.5"
        }
    }
    {
      "name": "CatBoostDemo",
      "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py",
      "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar",
      "conf": {
        "spark.driver.resourceSpec": "large",
        "spark.executor.instances": 2,
        "spark.executor.resourceSpec": "medium",
        "spark.adb.version": "3.5"
      }
    }
    说明

    本文Python应用程序需调用Scala,因此需要在Spark Jar作业代码中添加jars参数,并指向Maven依赖所在的OSS路径。

    参数说明:

    参数

    是否必填

    说明

    参数

    是否必填

    说明

    name

    Spark应用名称。

    file

    • Scala:步骤二中编写的Scala程序所在的OSS路径。

    • Python:步骤二中编写的Python程序所在的OSS路径。

    jars

    步骤一中准备的Maven依赖所在的OSS路径。

    ClassName

    条件必填

    Scala程序入口类名称。

    pyFiles

    条件必填

    步骤一中准备的Maven依赖所在的OSS路径。

    spark.adb.version

    Spark版本,此处必须显式指定为3.5。

    其他conf参数

    与开源Spark中的配置项基本一致,参数格式为key: value形式,多个参数之间以英文逗号(,)分隔。与开源Spark用法不一致的配置参数及AnalyticDB for MySQL特有的配置参数,请参见Spark应用配置参数说明

  5. (可选)在应用列表中,单击目标作业操作列的日志,在日志中可以查看返回结果。

  • 本页导读 (1)
  • 前提条件
  • 操作步骤
  • 步骤一:准备Maven依赖并上传至OSS
  • 步骤二:编写应用程序并上传至OSS
  • 步骤三:提交Spark作业