一、概述 XGBoost是一種基於決策樹的集成學習演算法,它在處理結構化數據方面表現優異。相比其他演算法,XGBoost能夠處理大量特征和樣本,並且支持通過正則化控制模型的複雜度。XGBoost也可以自動進行特征選擇並對缺失值進行處理。 二、代碼實現步驟 1、導入相關庫 import org.apach ...
一、概述
XGBoost是一種基於決策樹的集成學習演算法,它在處理結構化數據方面表現優異。相比其他演算法,XGBoost能夠處理大量特征和樣本,並且支持通過正則化控制模型的複雜度。XGBoost也可以自動進行特征選擇並對缺失值進行處理。
二、代碼實現步驟
1、導入相關庫
import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.feature.VectorAssembler; import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SparkSession;
2、載入數據
SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate();
DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv");
3、準備特征向量
String[] featureCols = data.columns(); featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1); VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features"); DataFrame inputData = assembler.transform(data).select("features", "output"); inputData.show(false);
4、劃分訓練集和測試集
double[] weights = {0.7, 0.3}; DataFrame[] splitData = inputData.randomSplit(weights); DataFrame train = splitData[0]; DataFrame test = splitData[1];
5、定義XGBoost模型
GBTRegressor gbt = new GBTRegressor() .setLabelCol("output") .setFeaturesCol("features") .setMaxIter(100) .setStepSize(0.1) .setMaxDepth(6) .setLossType("squared") .setFeatureSubsetStrategy("auto");
6、構建管道
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt});
7、訓練模型
GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0];
8、進行預測並評估模型
DataFrame predictions = model.transform(test); predictions.show(false); RegressionEvaluator evaluator = new RegressionEvaluator() .setMetricName("rmse") .setLabelCol("output") .setPredictionCol("prediction"); double rmse = evaluator.evaluate(predictions); System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);
以上就是Java語言中基於SparkML的XGBoost演算法實現的示例代碼。需要註意的是,這裡使用了GBTRegressor作為XGBoost的實現方式,但是也可以使用其他實現方式,例如XGBoostRegressor或者XGBoostClassification。
三、完整代碼
import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.evaluation.RegressionEvaluator; import org.apache.spark.ml.feature.VectorAssembler; import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SparkSession; import java.util.Arrays; public class XGBoostExample { public static void main(String[] args) { SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate(); // 載入數據 DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv"); data.printSchema(); data.show(false); // 準備特征向量 String[] featureCols = data.columns(); featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1); VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features"); DataFrame inputData = assembler.transform(data).select("features", "output"); inputData.show(false); // 劃分訓練集和測試集 double[] weights = {0.7, 0.3}; DataFrame[] splitData = inputData.randomSplit(weights); DataFrame train = splitData[0]; DataFrame test = splitData[1]; // 定義XGBoost模型 GBTRegressor gbt = new GBTRegressor() .setLabelCol("output") .setFeaturesCol("features") .setMaxIter(100) .setStepSize(0.1) .setMaxDepth(6) .setLossType("squared") .setFeatureSubsetStrategy("auto"); // 構建管道 Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt}); // 訓練模型 GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0]; // 進行預測並評估模型 DataFrame predictions = model.transform(test); predictions.show(false); RegressionEvaluator evaluator = new RegressionEvaluator() .setMetricName("rmse") .setLabelCol("output") .setPredictionCol("prediction"); double rmse = evaluator.evaluate(predictions); System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); spark.stop(); } }
在運行代碼之前需要將數據文件data.csv
放置到程式所在目錄下,以便載入數據。另外,需要將代碼中的相關路徑和參數按照實際情況進行修改。