本文通过示例为您介绍如何基于开源XGBoost在Data Science集群上进行分布式训练。您可以根据本文示例自行适配您的数据，修改提供的代码示例以进行定制化建模。

前提条件

已创建Data Science集群，并已在执行提交命令的节点当前目录下准备好示例数据iris.csv和打包好的作业文件xgboosttraining-0.1-SNAPSHOT.jar。

背景信息

分布式训练基于Spark on Yarn。示例代码使用Scala语言。

代码示例

您可以在IDE中查看SparkTraining.scala，代码如下。

/**
 * Distributed XGBoost training example on Spark.
 *
 * Reads the Iris dataset from CSV and converts the string `class` label into a numeric index.
 * It then assembles the four measurements into a feature vector and trains a multi-class
 * XGBoostClassifier with two XGBoost workers. Finally, it prints the predictions for a
 * held-out test split.
 */
object SparkTraining {

  /** Numeric feature columns of the input CSV, in file order (the label column follows them). */
  private val FeatureCols: Array[String] =
    Array("sepal length", "sepal width", "petal length", "petal width")

  /**
   * Entry point.
   *
   * @param args `args(0)` is the path of the input CSV (for example an HDFS URI). The file is
   *             read with the default CSV options (no header row) and must contain the four
   *             feature columns followed by the string `class` label.
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 1) {
      // scalastyle:off println
      println("Usage: program input_path")
      // scalastyle:on println
      sys.exit(1)
    }
    val inputPath = args(0)
    val spark = SparkSession.builder().getOrCreate()
    try {
      val schema = new StructType(
        FeatureCols.map(StructField(_, DoubleType, nullable = true)) :+
          StructField("class", StringType, nullable = true))
      val rawInput = spark.read.schema(schema).csv(inputPath)

      // XGBoost needs a numeric label, so map each distinct class string to an index.
      val labelIndexer = new StringIndexer()
        .setInputCol("class")
        .setOutputCol("classIndex")
        .fit(rawInput)
      val labelTransformed = labelIndexer.transform(rawInput).drop("class")

      // Combine all feature columns into the single vector column the classifier reads.
      val vectorAssembler = new VectorAssembler()
        .setInputCols(FeatureCols)
        .setOutputCol("features")
      val xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex")

      // 60% training, 20% + 10% passed to XGBoost as evaluation sets, 10% held out for testing.
      val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))

      /*
       * Optional parameters that can be added to xgbParam:
       *  - "timeout_request_workers" -> 60000L makes this application fail if it cannot get
       *    enough resources to start 2 workers within 60000 ms.
       *  - "checkpoint_path" -> "/checkpoints" together with "checkpoint_interval" -> 2 saves
       *    a checkpoint every two iterations.
       */
      val xgbParam: Map[String, Any] = Map(
        "eta" -> 0.1f,
        "max_depth" -> 2,
        "objective" -> "multi:softprob",
        // Derived from the data instead of hard-coding 3, so other label sets work unchanged.
        "num_class" -> labelIndexer.labels.length,
        "num_round" -> 100,
        "num_workers" -> 2,
        "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))

      val xgbClassifier = new XGBoostClassifier(xgbParam)
        .setFeaturesCol("features")
        .setLabelCol("classIndex")
      val xgbClassificationModel = xgbClassifier.fit(train)
      val results = xgbClassificationModel.transform(test)
      results.show()
    } finally {
      // Release the SparkContext (and its cluster resources) even if training fails.
      spark.stop()
    }
  }
}

运行代码

您可以通过以下两种方式运行代码:
  • 本地Local模式
    #!/bin/sh
    hadoop fs -put -f iris.csv hdfs://emr-header-1:9000/
    spark-submit --master 'local[8]' \
    --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining xgboosttraining-0.1-SNAPSHOT.jar \
    hdfs://emr-header-1:9000/iris.csv
  • 分布式Yarn-Cluster模式
    #!/bin/sh
    hadoop fs -put -f iris.csv hdfs://emr-header-1:9000/
    spark-submit --master yarn --deploy-mode cluster \
    --class ml.dmlc.xgboost4j.scala.example.spark.SparkTraining xgboosttraining-0.1-SNAPSHOT.jar \
    hdfs://emr-header-1:9000/iris.csv