本文通过示例为您介绍,如何基于开源XGBoost在Data Science集群进行分布式训练。您可以根据本文示例自行适配您的数据,修改提供的代码示例以进行定制化建模。
前提条件
- 开发工具
- 本地安装了Java JDK 8。
- 本地安装了Maven 3.x。
- 本地安装了用于Java或Scala开发的IDE,推荐IntelliJ IDEA,且已配置完成JDK和Maven环境。
- 已创建DataScience集群,详情请参见创建集群。
- 下载dsdemo代码:请已创建DataScience集群的用户,使用钉钉搜索钉钉群号32497587加入钉钉群以获取dsdemo代码。
背景信息
分布式训练基于Spark on Yarn。示例代码使用Scala语言。
代码示例
您可以在IDE中查看SparkTraining.java,代码如下。
object SparkTraining {
def main(args: Array[String]): Unit = {
if (args.length < 1) {
// scalastyle:off
println("Usage: program input_path")
sys.exit(1)
}
val spark = SparkSession.builder().getOrCreate()
val inputPath = args(0)
val schema = new StructType(Array(
StructField("sepal length", DoubleType, true),
StructField("sepal width", DoubleType, true),
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
val rawInput = spark.read.schema(schema).csv(inputPath)
// transform class to index to make xgboost happy
val stringIndexer = new StringIndexer()
.setInputCol("class")
.setOutputCol("classIndex")
.fit(rawInput)
val labelTransformed = stringIndexer.transform(rawInput).drop("class")
// compose all feature columns as vector
val vectorAssembler = new VectorAssembler().
setInputCols(Array("sepal length", "sepal width", "petal length", "petal width")).
setOutputCol("features")
val xgbInput = vectorAssembler.transform(labelTransformed).select("features",
"classIndex")
val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))
/**
* setup "timeout_request_workers" -> 60000L to make this application if it cannot get enough resources
* to get 2 workers within 60000 ms
*
* setup "checkpoint_path" -> "/checkpoints" and "checkpoint_interval" -> 2 to save checkpoint for every
* two iterations
*/
val xgbParam = Map("eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "multi:softprob",
"num_class" -> 3,
"num_round" -> 100,
"num_workers" -> 2,
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgbClassifier = new XGBoostClassifier(xgbParam).
setFeaturesCol("features").
setLabelCol("classIndex")
val xgbClassificationModel = xgbClassifier.fit(train)
val results = xgbClassificationModel.transform(test)
results.show()
}
}
运行代码
您可以通过以下两种方式运行代码:
- 本地Local模式
#!/bin/sh hadoop fs -put -f iris.csv hdfs://emr-header-1:9000/ spark-submit --master 'local[8]' \ --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining xgboosttraining-0.1-SNAPSHOT.jar \ hdfs://emr-header-1:9000/iris.csv
- 分布式Yarn-Cluster模式
#!/bin/sh hadoop fs -put -f iris.csv hdfs://emr-header-1:9000/ spark-submit --master yarn-cluster \ --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining xgboosttraining-0.1-SNAPSHOT.jar \ hdfs://emr-header-1:9000/iris.csv