From 9fec6fd7cfc2c282c012765f5b53e534890030e5 Mon Sep 17 00:00:00 2001 From: Diankun An <54262787+piaolaidelangman@users.noreply.github.com> Date: Fri, 10 Dec 2021 12:37:14 +0800 Subject: [PATCH] [XGBoost] Add xgboostclassifier predict example in scala (#3693) * Add xgboostclassifier predict example in scala with Iris Data Set from UCI * Local test with spark 2.4 and spark 3.1 * Update README.md * Add pr test scripts. --- apps/run-scala-app-test.sh | 30 +++++++++++ .../dllib/example/nnframes/xgboost/README.md | 53 ++++++++++++++++--- .../xgboost/xgbClassifierPredictExample.scala | 52 ++++++++++++++++++ .../bigdl/dllib/nnframes/NNClassifier.scala | 5 ++ 4 files changed, 134 insertions(+), 6 deletions(-) create mode 100644 scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/xgbClassifierPredictExample.scala diff --git a/apps/run-scala-app-test.sh b/apps/run-scala-app-test.sh index 46c235b4150..21930fd4a76 100644 --- a/apps/run-scala-app-test.sh +++ b/apps/run-scala-app-test.sh @@ -190,6 +190,34 @@ echo "#App[Model-inference-example] Test 5.1: model-inference-flink: Image Class ./flink-1.7.2/bin/stop-cluster.sh +echo "# Test 6.1 dllib nnframes: XGBoostClassifierTrainExample" +#timer +mkdir /tmp/data +wget $FTP_URI/analytics-zoo-data/iris.data -P /tmp/data +start=$(date "+%s") +${SPARK_HOME}/bin/spark-submit \ + --master local[4] \ + --conf spark.task.cpus=2 \ + --class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierTrainingExample \ + ${BIGDL_ROOT}/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \ + /tmp/data/iris.data 2 200 /tmp/data/xgboost_model +now=$(date "+%s") +time9=$((now-start)) +echo "#App[Model-inference-example] Test 6.1: dllib nnframes: XGBoostClassifierTrainExample time used:$time9 seconds" + +echo "# Test 6.2 dllib nnframes: XGBoostClassifierPredictExample" +#timer +start=$(date "+%s") +${SPARK_HOME}/bin/spark-submit \ + --master local[4] \ + --conf spark.task.cpus=2 \ + --class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierPredictExample \ + ${BIGDL_ROOT}/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \ + /tmp/data/iris.data /tmp/data/xgboost_model +now=$(date "+%s") +time10=$((now-start)) +echo "#App[Model-inference-example] Test 6.2: dllib nnframes: XGBoostClassifierPredictExample time used:$time10 seconds" + echo "#2 text-classification-training time used:$time2 seconds" echo "#3.1 text-classification-inference:SimpleDriver time used:$time3 seconds" #echo "#3.2 text-classification-inference:WebServiceDriver time used:$time4 seconds" @@ -197,3 +225,5 @@ echo "#4.1 recommendation-inference:SimpleScalaDriver time used:$time5 seconds" echo "#4.2 recommendation-inference:SimpleDriver time used:$time6 seconds" echo "#5.1 model-inference-flink:Text Classification time used:$time7 seconds" echo "#5.2 model-inference-flink:Image Classification time used:$time8 seconds" +echo "#6.1: dllib nnframes: XGBoostClassifierTrainExample time used:$time9 seconds" +echo "#6.2: dllib nnframes: XGBoostClassifierPredictExample time used:$time10 seconds" diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/README.md b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/README.md index 2f706101948..fdb58c64933 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/README.md +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/README.md @@ -1,7 +1,7 @@ -# XGBoostClassifier Train Example +# Prepare ## Environment -- Spark 2.4 +- Spark 2.4 or Spark 3.1 - BigDL 2.0 ## Data Prepare @@ -9,21 +9,23 @@ ### BigDL nightly build You can download [here](https://bigdl.readthedocs.io/en/latest/doc/release.html). -You will get jar `bigdl-dllib-spark_2.4.6-0.14.0-build_time-jar-with-dependencies.jar`. +For spark 2.4 you need `bigdl-dllib-spark_2.4.6-0.14.0-build_time-jar-with-dependencies.jar` or `bigdl-dllib-spark_3.1.2-0.14.0-build_time-jar-with-dependencies.jar` for spark 3.1 . + ### UCI iris.data You can download iris.data [here](https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data). +# XGBoostClassifier Train Example ## Run: command: ``` spark-submit \ --master local[2] \ - --conf spark.task.cpus=2 \ + --conf spark.task.cpus=2 \ --class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierTrainingExample \ - /path/to/BigDL/scala/dllib/target/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \ + /path/to/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \ /path/to/iris.data 2 100 /path/to/model/saved ``` @@ -46,4 +48,43 @@ parameters: - num_round : Int - path_to_model_saved : String -note: make sure num_threads is larger than spark.task.cpus. +**note: make sure num_threads is larger than spark.task.cpus.** + +# XGBoostClassifier Predict Example +## Run: +``` +spark-submit \ + --master local[4] \ + --conf spark.task.cpus=2 \ + --class com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost.xgbClassifierPredictExample \ + /path/to/bigdl-dllib-spark_2.4.6-0.14.0-SNAPSHOT-jar-with-dependencies.jar \ + /path/to/iris.data 2 100 /path/to/model/saved +``` +You will get output like: +``` ++------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+ +|sepal length|sepal width|petal length|petal width| class| rawPrediction| probability|prediction| ++------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+ +| 5.1| 3.5| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.9| 3.0| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0| +| 4.7| 3.2| 1.3| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.6| 3.1| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.0| 3.6| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.4| 3.9| 1.7| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.6| 3.4| 1.4| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.0| 3.4| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.4| 2.9| 1.4| 0.2|Iris-setosa|[2.94163084030151...|[0.97911602258682...| 0.0| +| 4.9| 3.1| 1.5| 0.1|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.4| 3.7| 1.5| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.8| 3.4| 1.6| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 4.8| 3.0| 1.4| 0.1|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0| +| 4.3| 3.0| 1.1| 0.1|Iris-setosa|[2.94163084030151...|[0.98863482475280...| 0.0| +| 5.8| 4.0| 1.2| 0.2|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.7| 4.4| 1.5| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.4| 3.9| 1.3| 0.4|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.1| 3.5| 1.4| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.7| 3.8| 1.7| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| +| 5.1| 3.8| 1.5| 0.3|Iris-setosa|[2.94163084030151...|[0.99256813526153...| 0.0| ++------------+-----------+------------+-----------+-----------+--------------------+--------------------+----------+ +only showing top 20 rows +``` diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/xgbClassifierPredictExample.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/xgbClassifierPredictExample.scala new file mode 100644 index 00000000000..246bcbeef66 --- /dev/null +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/example/nnframes/xgboost/xgbClassifierPredictExample.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2016 The BigDL Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.intel.analytics.bigdl.dllib.examples.nnframes.xgboost + +import com.intel.analytics.bigdl.dllib.NNContext +import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifierModel + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} +object xgbClassifierPredictExample { + def main(args: Array[String]): Unit = { + if (args.length < 2) { + println("Usage: program input_path model_path") + sys.exit(1) + } + val input_path = args(0) + val model_path = args(1) + + val sc = NNContext.initNNContext() + val spark = SQLContext.getOrCreate(sc) + + val schema = new StructType(Array( + StructField("sepal length", DoubleType, true), + StructField("sepal width", DoubleType, true), + StructField("petal length", DoubleType, true), + StructField("petal width", DoubleType, true), + StructField("class", StringType, true))) + val df = spark.read.schema(schema).csv(input_path) + + val model = XGBClassifierModel.load(model_path) + model.setFeaturesCol(Array("sepal length", "sepal width", "petal length", "petal width")) + + val results = model.transform(df) + results.show() + + sc.stop() + } +} diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nnframes/NNClassifier.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nnframes/NNClassifier.scala index d6cbe8c835c..cba98353707 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nnframes/NNClassifier.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nnframes/NNClassifier.scala @@ -448,6 +448,11 @@ object XGBClassifierModel { def load(path: String, numClass: Int): XGBClassifierModel = { new XGBClassifierModel(XGBoostHelper.load(path, numClass)) } + + def load(path: String): XGBClassifierModel = { + new XGBClassifierModel(XGBoostClassificationModel.load(path)) + } + } /**