A Full Integration of XGBoost and Apache Spark

October 26, 2016

(This article was first published on DMLC, and kindly contributed to R-bloggers)


In March 2016, we released the first version of XGBoost4J, a set of packages providing Java/Scala interfaces to XGBoost and integration with prevalent JVM-based distributed data processing platforms such as Spark and Flink.

The integrations with Spark and Flink, a.k.a. XGBoost4J-Spark and XGBoost4J-Flink, have received tremendously positive feedback from the community. They enable users to build a unified pipeline that embeds XGBoost into a data processing system based on widely deployed frameworks like Spark. The following figure shows the general architecture of such a pipeline with the first version of XGBoost4J-Spark, where data processing is based on the low-level Resilient Distributed Dataset (RDD) abstraction.

[Figure: XGBoost4J Architecture]

Over the last few months, we have communicated extensively with users and gained a deeper understanding of their latest usage scenarios and requirements:

  • XGBoost keeps gaining deployments in production environments and adoption in machine learning competitions (Link).

  • While Spark is still the mainstream data processing tool in most scenarios, more and more users are porting their RDD-based Spark programs to the DataFrame/Dataset APIs, both for their well-designed interfaces for manipulating structured data and for their significant performance improvements.

  • Spark itself has presented a clear roadmap in which DataFrame/Dataset is the basis of the latest and future features, e.g. the latest version of the ML pipeline and Structured Streaming.

Based on this feedback, we observed a gap between the original RDD-based XGBoost4J-Spark and both the users' latest usage scenarios and the future direction of the Spark ecosystem. To fill this gap, we started working on the integration of XGBoost and Spark's DataFrame/Dataset abstraction in September. In this blog post, we introduce the latest version of XGBoost4J-Spark, which allows users to work with DataFrame/Dataset directly and to embed XGBoost into Spark's ML pipeline seamlessly.

A Full Integration of XGBoost and DataFrame/Dataset

The following figure illustrates the new pipeline architecture with the latest XGBoost4J-Spark.

[Figure: XGBoost4J New Architecture]

Unlike the previous version, the latest one lets users work with both the low- and high-level memory abstractions in Spark, i.e. RDD and DataFrame/Dataset. The DataFrame/Dataset abstraction allows users to manipulate structured datasets and to use Spark's built-in routines or User Defined Functions (UDFs) to explore the value distribution in columns before feeding data into the machine learning phase of the pipeline. In the following example, structured sales records saved in a JSON file are parsed into a DataFrame through Spark's API and fed to XGBoost for training in two lines of Scala code.

// load sales records saved in json files
val salesDF = spark.read.json("sales.json")
// call XGBoost API to train with the DataFrame-represented training set
val xgboostModel = XGBoost.trainWithDataFrame(
      salesDF, paramMap, numRound, nWorkers, useExternalMemory)
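
The call above relies on paramMap, numRound, nWorkers and useExternalMemory being defined elsewhere. As a rough sketch, they might be set up as follows. The values here are illustrative assumptions rather than settings from the original article; eta, max_depth and objective are standard XGBoost parameter names.

```scala
// Illustrative values only (assumptions); tune them for your data and cluster.
val paramMap: Map[String, Any] = Map(
  "eta" -> 0.1,                // learning rate (step-size shrinkage)
  "max_depth" -> 6,            // maximum depth of each tree
  "objective" -> "reg:linear"  // squared-error regression (reg:squarederror in later XGBoost releases)
)
val numRound = 100             // number of boosting rounds
val nWorkers = 4               // number of distributed XGBoost workers
val useExternalMemory = false  // whether to spill the training data cache to disk
```

Typing the map as Map[String, Any] lets numeric and string parameter values live in the same map.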

By integrating with DataFrame/Dataset, XGBoost4J-Spark not only enables users to call DataFrame/Dataset APIs directly but also makes DataFrame/Dataset-based Spark features, e.g. the ML package, available to XGBoost users.

Integration with ML Package

The ML package of Spark provides a set of convenient tools for feature extraction/transformation/selection. Additionally, with the model selection tool in the ML package, users can select the best model through an automatic parameter search process defined through the ML package APIs. After the integration with the DataFrame/Dataset abstraction, these features of the ML package are also available to XGBoost users.

Feature Extraction/Transformation/Selection

The following example shows a feature transformer that converts the string-typed storeType feature to the numeric storeTypeIndex. The transformed DataFrame is then used to train the XGBoost model.

import org.apache.spark.ml.feature.StringIndexer

// load sales records saved in json files
val salesDF = spark.read.json("sales.json")

// transform the string-represented storeType feature to numeric storeTypeIndex
val indexer = new StringIndexer()
      .setInputCol("storeType")
      .setOutputCol("storeTypeIndex")
// drop the original string column once it has been indexed
val indexed = indexer.fit(salesDF).transform(salesDF).drop("storeType")

// use the transformed dataframe as training dataset
val xgboostModel = XGBoost.trainWithDataFrame(
      indexed, paramMap, numRound, nWorkers, useExternalMemory)
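
To see what the indexing step produces without a sales.json file at hand, the following sketch runs StringIndexer on a tiny in-memory DataFrame. The toy records and the local SparkSession settings are assumptions made for illustration; only the column names come from the example above. StringIndexer orders labels by descending frequency by default, so the most frequent store type receives index 0.0.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

// Local session for experimentation only (assumption: not a cluster setup)
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("string-indexer-sketch")
  .getOrCreate()
import spark.implicits._

// Made-up records; "mall" appears twice, "street" once
val toyDF = Seq((1, "mall"), (2, "street"), (3, "mall")).toDF("storeId", "storeType")

// Index the string column, then drop the original string column
val toyIndexed = new StringIndexer()
  .setInputCol("storeType")
  .setOutputCol("storeTypeIndex")
  .fit(toyDF)
  .transform(toyDF)
  .drop("storeType")

toyIndexed.show()
```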


Spark's ML package allows users to build a complete pipeline from feature extraction/transformation/selection to model training. We integrated XGBoost with the ML package so that XGBoost can be embedded into such a pipeline seamlessly. The following example shows how to build such a pipeline, consisting of feature transformers and the XGBoost estimator.

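import ml.dmlc.xgboost4j.scala.spark.XGBoostEstimator
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}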

// load sales records saved in json files
val salesDF = spark.read.json("sales.json")

// transform the string-represented storeType feature to numeric storeTypeIndex
val storeTypeIndexer = new StringIndexer()
      .setInputCol("storeType")
      .setOutputCol("storeTypeIndex")

// assemble the columns in dataframe into a vector
val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("storeId", "storeTypeIndex", ...))
      .setOutputCol("features")

// construct the pipeline
val pipeline = new Pipeline().setStages(
      Array(storeTypeIndexer, ..., vectorAssembler,
            new XGBoostEstimator(Map[String, Any]("num_rounds" -> 100))))

// fit the whole pipeline (feature transformers and XGBoost) on the training DataFrame
val xgboostModel = pipeline.fit(salesDF)

// predict with the trained model
val salesTestDF = spark.read.json("sales_test.json")
val salesRecordsWithPred = xgboostModel.transform(salesTestDF)
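
The feature stages of such a pipeline can be tried in isolation. The sketch below chains only StringIndexer and VectorAssembler on made-up records and deliberately leaves out the XGBoost estimator, so it runs with Spark alone. The toy data and local session are assumptions for illustration; in a real pipeline the estimator would be appended as the last stage.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.SparkSession

// Local session for experimentation only
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("feature-pipeline-sketch")
  .getOrCreate()
import spark.implicits._

// Made-up sales records using the column names from the article
val toyDF = Seq((1, "mall", 120.0), (2, "street", 80.0), (3, "mall", 150.0))
  .toDF("storeId", "storeType", "sales")

val storeTypeIndexer = new StringIndexer()
  .setInputCol("storeType")
  .setOutputCol("storeTypeIndex")

val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("storeId", "storeTypeIndex"))
  .setOutputCol("features")

// Feature stages only; a model estimator would normally follow the assembler
val featurePipeline = new Pipeline().setStages(Array(storeTypeIndexer, vectorAssembler))
val prepared = featurePipeline.fit(toyDF).transform(toyDF)

prepared.select("features", "sales").show(false)
```

Each row of prepared now carries a features vector built from storeId and storeTypeIndex, which is the input shape an estimator configured with setFeaturesCol("features") expects.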

Model Selection

The most critical step in maximizing the power of XGBoost is selecting the optimal parameters for the model. Tuning parameters manually is a tedious and labor-intensive process. With the latest version of XGBoost4J-Spark, we can use Spark's model selection tool to automate this process. The following code snippet uses TrainValidationSplit and RegressionEvaluator to search for the optimal combination of two XGBoost parameters, max_depth and eta. The model that produces the minimum value of the cost function defined by RegressionEvaluator is selected and used to generate predictions for the test set.

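import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

// create XGBoostEstimator
val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features")
      .setLabelCol("sales")
val paramGrid = new ParamGridBuilder()
      .addGrid(xgbEstimator.maxDepth, Array(5, 6))
      .addGrid(xgbEstimator.eta, Array(0.1, 0.4))
      .build()
val tv = new TrainValidationSplit()
      .setEstimator(xgbEstimator)
      .setEvaluator(new RegressionEvaluator().setLabelCol("sales"))
      .setEstimatorParamMaps(paramGrid)
// fit on the training DataFrame (assumed here to already contain a "features" vector column)
val xgboostModel = tv.fit(salesDF)
val salesTestDF = spark.read.json("sales_test.json")
val salesRecordsWithPred = xgboostModel.transform(salesTestDF)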
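
After a parameter search it is often useful to see how each candidate scored. The sketch below shows one way to inspect a fitted TrainValidationSplit model. To stay runnable with Spark alone, it swaps in Spark's built-in LinearRegression as a stand-in for the XGBoost estimator and uses synthetic data; both are assumptions made for illustration, and the same inspection calls should apply to any estimator tuned this way.

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.SparkSession

// Local session for experimentation only
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("model-selection-sketch")
  .getOrCreate()
import spark.implicits._

// Synthetic training data: sales = 2 * x
val toyDF = (1 to 40).map(i => (Vectors.dense(i.toDouble), 2.0 * i)).toDF("features", "sales")

// Stand-in estimator (NOT XGBoost), used so the sketch has no extra dependencies
val standIn = new LinearRegression().setFeaturesCol("features").setLabelCol("sales")

val grid = new ParamGridBuilder()
  .addGrid(standIn.regParam, Array(0.0, 0.1))
  .build()

val tvSplit = new TrainValidationSplit()
  .setEstimator(standIn)
  .setEvaluator(new RegressionEvaluator().setLabelCol("sales"))
  .setEstimatorParamMaps(grid)
  .setTrainRatio(0.8)
  .setSeed(42L)

val tvModel = tvSplit.fit(toyDF)

// One validation metric (RMSE, the RegressionEvaluator default) per parameter combination
tvModel.getEstimatorParamMaps.zip(tvModel.validationMetrics).foreach {
  case (params, rmse) => println(s"$params -> RMSE $rmse")
}
```

Because RMSE is a lower-is-better metric, TrainValidationSplit keeps the candidate with the smallest value and exposes it as tvModel.bestModel.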


With the latest XGBoost4J-Spark, XGBoost users can build a more efficient data processing pipeline that uses the DataFrame/Dataset APIs to handle structured data with excellent performance, while using XGBoost to extract insights from the dataset and turn those insights into action. Additionally, XGBoost4J-Spark seamlessly connects XGBoost with the Spark ML package, which makes feature extraction/transformation/selection and model parameter tuning much easier than before.

The latest version of XGBoost4J-Spark is available in the GitHub repository, and the latest API docs are available here.

Portable Machine Learning Systems

XGBoost is one of the projects incubated by the Distributed Machine Learning Community (DMLC), which has also created several other popular projects on machine learning systems (Link), e.g. MXNet, one of the most popular deep learning frameworks. We strongly believe that a machine learning solution should not be restricted to a certain language or a certain platform. We have realized this design philosophy in several projects, such as XGBoost and MXNet, and we welcome more contributions from the community in this direction.

Further Readings

If you are interested in knowing more about XGBoost, you can find rich resources in
