Skip to content

Commit

Permalink
[XGBoost] Add xgboostclassifier predict example in scala (#3693)
Browse files Browse the repository at this point in the history
* Add xgboostclassifier predict example in scala with Iris Data Set from UCI
* Local test with spark 2.4 and spark 3.1
* Update README.md
* Add pr test scripts.
  • Loading branch information
piaolaidelangman authored Dec 10, 2021
1 parent 45c02cd commit 9fec6fd
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 6 deletions.
30 changes: 30 additions & 0 deletions apps/run-scala-app-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,40 @@ echo "#App[Model-inference-example] Test 5.1: model-inference-flink: Image Class

./flink-1.7.2/bin/stop-cluster.sh

echo "# Test 6.1 dllib nnframes: XGBoostClassifierTrainExample"
#timer
mkdir /tmp/data
wget $FTP_URI/analytics-zoo-data/iris.data -P /tmp/data
start=$(date "+%s")
${SPARK_HOME}/bin/spark-submit \
--master local[4] \
--conf spark.task.cpus=2 \
--class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierTrainingExample \
${BIGDL_ROOT}/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \
/tmp/data/iris.data 2 200 /tmp/data/xgboost_model
now=$(date "+%s")
time9=$((now-start))
echo "#App[Model-inference-example] Test 6.1: dllib nnframes: XGBoostClassifierTrainExample time used:$time9 seconds"

echo "# Test 6.2 dllib nnframes: XGBoostClassifierPredictExample"
#timer
start=$(date "+%s")
${SPARK_HOME}/bin/spark-submit \
--master local[4] \
--conf spark.task.cpus=2 \
--class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierPredictExample \
${BIGDL_ROOT}/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \
/tmp/data/iris.data /tmp/data/xgboost_model
now=$(date "+%s")
time10=$((now-start))
echo "#App[Model-inference-example] Test 6.2: dllib nnframes: XGBoostClassifierPredictExample time used:$time10 seconds"

echo "#2 text-classification-training time used:$time2 seconds"
echo "#3.1 text-classification-inference:SimpleDriver time used:$time3 seconds"
#echo "#3.2 text-classification-inference:WebServiceDriver time used:$time4 seconds"
echo "#4.1 recommendation-inference:SimpleScalaDriver time used:$time5 seconds"
echo "#4.2 recommendation-inference:SimpleDriver time used:$time6 seconds"
echo "#5.1 model-inference-flink:Text Classification time used:$time7 seconds"
echo "#5.2 model-inference-flink:Image Classification time used:$time8 seconds"
echo "#6.1: dllib nnframes: XGBoostClassifierTrainExample time used:$time9 seconds"
echo "#6.2: dllib nnframes: XGBoostClassifierPredictExample time used:$time10 seconds"
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
# XGBoostClassifier Train Example
# Prepare

## Environment
- Spark 2.4
- Spark 2.4 or Spark 3.1
- BigDL 2.0

## Data Prepare

### BigDL nightly build

You can download [here](https://bigdl.readthedocs.io/en/latest/doc/release.html).
You will get jar `bigdl-dllib-spark_2.4.6-0.14.0-build_time-jar-with-dependencies.jar`.
For spark 2.4 you need `bigdl-dllib-spark_2.4.6-0.14.0-build_time-jar-with-dependencies.jar` or `bigdl-dllib-spark_3.1.2-0.14.0-build_time-jar-with-dependencies.jar` for spark 3.1 .


### UCI iris.data

You can download iris.data [here](https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data).

# XGBoostClassifier Train Example
## Run:

command:
```
spark-submit \
--master local[2] \
--conf spark.task.cpus=2 \
--conf spark.task.cpus=2 \
--class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierTrainingExample \
/path/to/BigDL/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \
/path/to/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \
/path/to/iris.data 2 100 /path/to/model/saved
```

Expand All @@ -46,4 +48,43 @@ parameters:
- num_round : Int
- path_to_model_saved : String

note: make sure num_threads is larger than spark.task.cpus.
**note: make sure num_threads is larger than spark.task.cpus.**

# XGBoostClassifier Predict Example
## Run:
```
spark-submit \
--master local[4] \
--conf spark.task.cpus=2 \
--class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierPredictExample \
/path/to/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \
/path/to/iris.data 2 100 /path/to/model/saved
```
You will get output like:
```
+------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+
|sepal length|sepal width|petal length|petal width| class| rawPrediction| probability|prediction|
+------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+
| 5.1| 3.5| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.9| 3.0| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0|
| 4.7| 3.2| 1.3| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.6| 3.1| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.0| 3.6| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.4| 3.9| 1.7| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.6| 3.4| 1.4| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.0| 3.4| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.4| 2.9| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.97911602258682...| 0.0|
| 4.9| 3.1| 1.5| 0.1|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.4| 3.7| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.8| 3.4| 1.6| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 4.8| 3.0| 1.4| 0.1|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0|
| 4.3| 3.0| 1.1| 0.1|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0|
| 5.8| 4.0| 1.2| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.7| 4.4| 1.5| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.4| 3.9| 1.3| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.1| 3.5| 1.4| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.7| 3.8| 1.7| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
| 5.1| 3.8| 1.5| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0|
+------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+
only showing top 20 rows
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost

import com.intel.analytics.bigdl.dllib.NNContext
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifierModel

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
object xgbClassifierPredictExample {
def main(args: Array[String]): Unit = {
if (args.length < 2) {
println("Usage: program input_path model_path")
sys.exit(1)
}
val input_path = args(0)
val model_path = args(1)

val sc = NNContext.initNNContext()
val spark = SQLContext.getOrCreate(sc)

val schema = new StructType(Array(
StructField("sepal length", DoubleType, true),
StructField("sepal width", DoubleType, true),
StructField("petal length", DoubleType, true),
StructField("petal width", DoubleType, true),
StructField("class", StringType, true)))
val df = spark.read.schema(schema).csv(input_path)

val model = XGBClassifierModel.load(model_path)
model.setFeaturesCol(Array("sepal length", "sepal width", "petal length", "petal width"))

val results = model.transform(df)
results.show()

sc.stop()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,11 @@ object XGBClassifierModel {
def load(path: String, numClass: Int): XGBClassifierModel = {
new XGBClassifierModel(XGBoostHelper.load(path, numClass))
}

def load(path: String): XGBClassifierModel = {
new XGBClassifierModel(XGBoostClassificationModel.load(path))
}

}

/**
Expand Down

0 comments on commit 9fec6fd

Please sign in to comment.