Java語言在Spark3.2.4集群中使用Spark MLlib庫完成XGboost演算法

来源:https://www.cnblogs.com/wxm2270/archive/2023/04/12/17310277.html
-Advertisement-
Play Games

一、概述 XGBoost是一種基於決策樹的集成學習演算法,它在處理結構化數據方面表現優異。相比其他演算法,XGBoost能夠處理大量特征和樣本,並且支持通過正則化控制模型的複雜度。XGBoost也可以自動進行特征選擇並對缺失值進行處理。 二、代碼實現步驟 1、導入相關庫 import org.apach ...


一、概述

XGBoost是一種基於決策樹的集成學習演算法,它在處理結構化數據方面表現優異。相比其他演算法,XGBoost能夠處理大量特征和樣本,並且支持通過正則化控制模型的複雜度。XGBoost也可以自動進行特征選擇並對缺失值進行處理。

二、代碼實現步驟

1、導入相關庫

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor};
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SparkSession;

2、載入數據

SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate();
DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv");

3、準備特征向量

String[] featureCols = data.columns();
featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1);
VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features");
DataFrame inputData = assembler.transform(data).select("features", "output");
inputData.show(false);

4、劃分訓練集和測試集

double[] weights = {0.7, 0.3};
DataFrame[] splitData = inputData.randomSplit(weights);
DataFrame train = splitData[0];
DataFrame test = splitData[1];

5、定義XGBoost模型

GBTRegressor gbt = new GBTRegressor()
    .setLabelCol("output")
    .setFeaturesCol("features")
    .setMaxIter(100)
    .setStepSize(0.1)
    .setMaxDepth(6)
    .setLossType("squared")
    .setFeatureSubsetStrategy("auto");

6、構建管道

Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt});

7、訓練模型

GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0];

8、進行預測並評估模型

DataFrame predictions = model.transform(test);
predictions.show(false);

RegressionEvaluator evaluator = new RegressionEvaluator()
    .setMetricName("rmse")
    .setLabelCol("output")
    .setPredictionCol("prediction");

double rmse = evaluator.evaluate(predictions);
System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

以上就是Java語言中基於SparkML的XGBoost演算法實現的示例代碼。需要註意的是,這裡使用了GBTRegressor作為XGBoost的實現方式,但是也可以使用其他實現方式,例如XGBoostRegressor或者XGBoostClassification。

三、完整代碼

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor};
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SparkSession;
import java.util.Arrays;

public class XGBoostExample {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("XGBoost").master("local[*]").getOrCreate();

        // 載入數據
        DataFrame data = spark.read().option("header", "true").option("inferSchema", "true").csv("data.csv");
        data.printSchema();
        data.show(false);

        // 準備特征向量
        String[] featureCols = data.columns();
        featureCols = Arrays.copyOfRange(featureCols, 0, featureCols.length - 1);
        VectorAssembler assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features");
        DataFrame inputData = assembler.transform(data).select("features", "output");
        inputData.show(false);

        // 劃分訓練集和測試集
        double[] weights = {0.7, 0.3};
        DataFrame[] splitData = inputData.randomSplit(weights);
        DataFrame train = splitData[0];
        DataFrame test = splitData[1];

        // 定義XGBoost模型
        GBTRegressor gbt = new GBTRegressor()
                .setLabelCol("output")
                .setFeaturesCol("features")
                .setMaxIter(100)
                .setStepSize(0.1)
                .setMaxDepth(6)
                .setLossType("squared")
                .setFeatureSubsetStrategy("auto");

        // 構建管道
        Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{gbt});

        // 訓練模型
        GBTRegressionModel model = (GBTRegressionModel) pipeline.fit(train).stages()[0];

        // 進行預測並評估模型
        DataFrame predictions = model.transform(test);
        predictions.show(false);

        RegressionEvaluator evaluator = new RegressionEvaluator()
                .setMetricName("rmse")
                .setLabelCol("output")
                .setPredictionCol("prediction");

        double rmse = evaluator.evaluate(predictions);
        System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

        spark.stop();
    }
}

在運行代碼之前需要將數據文件data.csv放置到程式所在目錄下,以便載入數據。另外,需要將代碼中的相關路徑和參數按照實際情況進行修改。 


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • RDIF.vNext,全新低代碼快速開發集成框架平臺,給用戶和開發者最佳的.Net框架平臺方案,為企業快速構建垮平臺、企業級的應用提供強大支持。框架採用最新主流技術開發(.Net6+/Vue前後端分離,支持分散式部署,跨平臺運行),前後端分離架構。支持常用多種資料庫類型,支持Web、App、客戶端應... ...
  • 軟鏈接與硬鏈接是用來乾什麼的呢? 為解決文件的共用使用,Linux 系統引入了兩種鏈接:硬鏈接 (hard link) 與軟鏈接(又稱符號鏈接,即 soft link 或 symbolic link)。鏈接為 Linux 系統解決了文件的共用使用,還帶來了隱藏文件路徑、增加許可權安全及節省存儲等好處。 ...
  • 最近筆記本重置後,台式使用“遠程桌面連接”遠程筆記本失敗了,總是提示“登錄沒有成功”。 開始自查:win10專業版,允許遠程的相關設置也都開了,連接的ip正確,也沒問題。因為我的筆記本用戶是用微軟賬戶登錄的,遠程時用戶名和密碼也要用微軟用戶名和密碼(不是那個PIN碼)。再三確認用戶名和密碼是對的後, ...
  • 問題描述 對於打開分區提示需要格式化的情況,右擊屬性查看時,文件系統變成了RAW了,沒有關係很好恢復,千萬不要格式化。 問題分析 可以看到該分區說明分區表沒有問題,這是由於DBR扇區(即啟動扇區)損壞造成的。 以上聽不懂分析沒有關係,對你的恢復影響不大。 有兩種方法恢復: 1、用軟體自動進行修複,如 ...
  • 在上篇文章 《細節拉滿,80 張圖帶你一步一步推演 slab 記憶體池的設計與實現 》中,筆者從 slab cache 的總體架構演進角度以及 slab cache 的運行原理角度為大家勾勒出了 slab cache 的總體架構視圖,基於這個視圖詳細闡述了 slab cache 的記憶體分配以及釋放原理 ...
  • DROP TABLE IF EXISTS qrtz_blob_triggers; DROP TABLE IF EXISTS qrtz_calendars; DROP TABLE IF EXISTS qrtz_cron_triggers; DROP TABLE IF EXISTS qrtz_fired ...
  • 鎖屏面試題百日百刷,每個工作日堅持更新面試題。請看到最後就能獲取你想要的,接下來的是今日的面試題: 1.解釋一下,在數據製作過程中,你如何能從Kafka得到準確的信息? 在數據中,為了精確地獲得Kafka的消息,你必須遵循兩件事: 在數據消耗期間避免重覆,在數據生產過程中避免重覆。 這裡有兩種方法, ...
  • 1. 背景 已知數據集為: 目的: 計算每個uid的連續活躍天數,並且每一段活躍期內的開始時間和結束時間 2. 步驟 第一步:處理數據集 處理數據集,使其滿足每個uid每個日期只有一條數據。 第二步:以uid為主鍵,按照日期進行排序,計算row_number. SELECT uid ,`徵信查詢日期 ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...