通过LightGBM工具训练GBDT模型
在广告点击预测、游戏用户付费或流失预测以及邮件自动分类等数据挖掘场景中,通常需要基于历史数据训练出用于分类的模型,以便预测后续行为。您可以使用云原生数据仓库 AnalyticDB MySQL 版Spark,通过LightGBM工具基于GBDT模型实现数据的分类和预测。与单机部署的XGBoost、CatBoost相比,部署在Spark上的LightGBM能够充分利用分布式计算能力,处理TB级大规模数据。本文介绍如何通过LightGBM工具训练GBDT模型,实现数据的分类和预测。
前提条件
集群的产品系列为企业版、基础版或湖仓版。
集群与OSS存储空间位于相同地域。
已在企业版、基础版或湖仓版集群中创建Job型资源组。具体操作,请参见新建和管理资源组。
已创建数据库账号。
如果是通过阿里云账号访问,只需创建高权限账号。具体操作,请参见创建高权限账号。
如果是通过RAM用户访问,需要创建高权限账号和普通账号并且将RAM用户绑定到普通账号上。具体操作,请参见创建数据库账号和绑定或解绑RAM用户与数据库账号。
操作步骤
步骤一:准备Maven依赖并上传至OSS
您可以通过以下任一方式获取Maven依赖:
通过链接下载Maven依赖的Jar包。
在IDEA的pom.xml中配置Maven依赖。代码如下:
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.aliyun.adb.spark</groupId> <artifactId>LightgbmDemo</artifactId> <version>1.0</version> <packaging>jar</packaging> <properties> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> </properties> <dependencies> <!-- https://mvnrepository.com/artifact/com.microsoft.azure/synapseml --> <dependency> <groupId>com.microsoft.azure</groupId> <artifactId>synapseml_2.12</artifactId> <version>1.0.8</version> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-launcher_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.12</artifactId> <version>3.5.1</version> <scope>provided</scope> </dependency> </dependencies> <build> <plugins> <plugin> <groupId>net.alchim31.maven</groupId> <artifactId>scala-maven-plugin</artifactId> <version>4.4.0</version> <executions> <execution> <goals> <goal>compile</goal> <goal>testCompile</goal> </goals> </execution> </executions> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> <version>3.1.1</version> <configuration> <createDependencyReducedPom>false</createDependencyReducedPom> </configuration> <executions> <execution> <phase>package</phase> <goals> <goal>shade</goal> </goals> </execution> </executions> </plugin> </plugins> </build> </project>
如果是在IDEA的pom.xml中配置的Maven依赖,需执行
mvn clean package -DskipTests
命令打包。通过链接下载的Maven依赖不需要打包,可跳过该步骤。将Maven依赖上传至OSS中。
步骤二:编写应用程序并上传至OSS
编写应用程序。
Scala程序Python程序package com.aliyun.adb import com.microsoft.azure.synapse.ml.lightgbm.LightGBMClassifier import org.apache.spark.sql.{Dataset, Row, SparkSession} object GBDTdemo { def main(args: Array[String]): Unit = { // 注册Spark Session val spark = SparkSession .builder() .appName("lightgbm Example") .getOrCreate() // 读取训练数据 val dataPath = s"${sys.env.getOrElse("SPARK_HOME", "/opt/spark")}/data/mllib/sample_multiclass_classification_data.txt" val data: Dataset[Row] = spark.read.format("libsvm").load(dataPath) data.show() val classifier = new LightGBMClassifier() classifier.setLabelCol("label") classifier.setFeaturesCol("features") // 设置Category型的特征列, 这里我们假设第0列和第1列是Category型的特征列 // 示例数据中不包含Category型的特征列,这里需注释这行代码,仅为演示 // classifier.setCategoricalSlotIndexes(Array(0, 1)) // classifier.setCategoricalSlotNames(Array.apply("enumCol1", "enumCol2")) // 将数据分为训练集和验证集 val Array(trainData, validData) = data.randomSplit(Array(0.6, 0.4)) val model = classifier.fit(trainData) model.saveNativeModel("oss://bucket_name/model", overwrite = true) // Run predictions on the validation data val predictions = model.transform(validData) // Show the prediction results predictions.show() // print accuracy val accuracy = predictions.filter("label == prediction").count().toDouble / predictions.count() println(s"Accuracy: $accuracy") System.exit(0) } }
import os from pyspark.sql import SparkSession if __name__ == '__main__': # init spark spark = SparkSession.builder.appName("lightgbm_spark_train").getOrCreate() # read data f = os.environ.get("SPARK_HOME") + "/data/mllib/sample_multiclass_classification_data.txt" df = spark.read.format("libsvm").load(f) # split data train, test = df.randomSplit([0.8, 0.2]) # train model from synapse.ml.lightgbm import LightGBMClassifier model = LightGBMClassifier(learningRate=0.3, numIterations=20, numLeaves=4).fit(train) # predict prediction = model.transform(test) prediction.show() # stop spark spark.stop()
如果是Scala程序,则需要将其打成Jar包。Python文件不需要打包,可跳过此步骤。
将打好的Jar包或.py的文件上传至OSS。
步骤三:提交Spark作业
登录云原生数据仓库AnalyticDB MySQL控制台,在左上角选择集群所在地域。在左侧导航栏,单击集群列表。在集群列表上方,选择产品系列,然后单击目标集群ID。
在左侧导航栏,单击
。选择Job型资源组和Spark作业类型。本文示例为Batch。
根据步骤二中编写的应用程序,在编辑器中输入以下代码后,单击立即执行。
Scala程序Python程序{ "name": "LightdbmDemo", "file": "oss://testBucketName/original-LightgbmDemo-1.0.jar", "jars": "oss://testBucketName/LightgbmDemo-1.0.jar", "ClassName": "com.aliyun.adb.GBDTdemo", "conf": { "spark.driver.resourceSpec": "large", "spark.executor.instances": 2, "spark.executor.resourceSpec": "medium", "spark.adb.version": "3.5" } }
{ "name": "CatBoostDemo", "file": "oss://testBucketName/GBDT/lightgbm_spark_20241227.py", "jars": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar", "pyFiles": "oss://testBucketName/GBDT/GDBTDemo-1.0-SNAPSHOT.jar", "conf": { "spark.driver.resourceSpec": "large", "spark.executor.instances": 2, "spark.executor.resourceSpec": "medium", "spark.adb.version": "3.5" } }
本文Python应用程序需调用Scala,因此需要在Spark Jar作业代码中添加
jars
参数,并指向Maven依赖所在的OSS路径。参数说明:
参数
是否必填
说明
参数
是否必填
说明
name
否
Spark应用名称。
file
是
Scala:步骤二中编写的Scala程序所在的OSS路径。
Python:步骤二中编写的Python程序所在的OSS路径。
jars
是
步骤一中准备的Maven依赖所在的OSS路径。
ClassName
条件必填
Scala程序入口类名称。
pyFiles
条件必填
步骤一中准备的Maven依赖所在的OSS路径。
spark.adb.version
是
Spark版本,此处必须显式指定为3.5。
其他conf参数
否
与开源Spark中的配置项基本一致,参数格式为
key: value
形式,多个参数之间以英文逗号(,)分隔。与开源Spark用法不一致的配置参数及AnalyticDB for MySQL特有的配置参数,请参见Spark应用配置参数说明。(可选)在应用列表中,单击目标作业操作列的日志,在日志中可以查看返回结果。
- 本页导读 (1)
- 前提条件
- 操作步骤
- 步骤一:准备Maven依赖并上传至OSS
- 步骤二:编写应用程序并上传至OSS
- 步骤三:提交Spark作业