K-Means演算法是一種基於距離的聚類演算法,採用迭代的方法,計算出K個聚類中心,把若幹個點聚成K類。 用Spark實現K-Means演算法,首先修改pom文件,引入機器學習MLlib包: 代碼: 使用textFile()方法裝載數據集,獲得RDD,再使用KMeans.train()方法根據RDD、K值 ...
K-Means演算法是一種基於距離的聚類演算法,採用迭代的方法,計算出K個聚類中心,把若幹個點聚成K類。
用Spark實現K-Means演算法,首先修改pom文件,引入機器學習MLlib包:
<dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.10</artifactId> <version>1.6.0</version> </dependency>
代碼:
import org.apache.log4j.{Level,Logger} import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.mllib.clustering.KMeans import org.apache.spark.mllib.linalg.Vectors object Kmeans { def main(args:Array[String]) = { // 屏蔽日誌 Logger.getLogger("org.apache.spark").setLevel(Level.WARN) Logger.getLogger("org.apache.jetty.server").setLevel(Level.OFF) // 設置運行環境 val conf = new SparkConf().setAppName("K-Means").setMaster("spark://master:7077") .setJars(Seq("E:\\Intellij\\Projects\\SimpleGraphX\\SimpleGraphX.jar")) val sc = new SparkContext(conf) // 裝載數據集 val data = sc.textFile("hdfs://master:9000/kmeans_data.txt", 1) val parsedData = data.map(s => Vectors.dense(s.split(" ").map(_.toDouble))) // 將數據集聚類,2個類,20次迭代,形成數據模型 val numClusters = 2 val numIterations = 20 val model = KMeans.train(parsedData, numClusters, numIterations) // 數據模型的中心點 println("Cluster centres:") for(c <- model.clusterCenters) { println(" " + c.toString) } // 使用誤差平方之和來評估數據模型 val cost = model.computeCost(parsedData) println("Within Set Sum of Squared Errors = " + cost) // 使用模型測試單點數據 println("Vectors 7.3 1.5 10.9 is belong to cluster:" + model.predict(Vectors.dense("7.3 1.5 10.9".split(" ") .map(_.toDouble)))) println("Vectors 4.2 11.2 2.7 is belong to cluster:" + model.predict(Vectors.dense("4.2 11.2 2.7".split(" ") .map(_.toDouble)))) println("Vectors 18.0 4.5 3.8 is belong to cluster:" + model.predict(Vectors.dense("1.0 14.5 73.8".split(" ") .map(_.toDouble)))) // 返回數據集和結果 val result = data.map { line => val linevectore = Vectors.dense(line.split(" ").map(_.toDouble)) val prediction = model.predict(linevectore) line + " " + prediction }.collect.foreach(println) sc.stop } }
使用textFile()方法裝載數據集,獲得RDD,再使用KMeans.train()方法根據RDD、K值和迭代次數得到一個KMeans模型。得到KMeans模型以後,可以判斷一組數據屬於哪一個類。具體方法是用Vectors.dense()方法生成一個Vector,然後用KMeans.predict()方法就可以返回屬於哪一個類。
運行結果:
Cluster centres:
[6.062499999999999,6.7124999999999995,11.5]
[3.5,12.2,60.0]
Within Set Sum of Squared Errors = 943.2074999999998
Vectors 7.3 1.5 10.9 is belong to cluster:0
Vectors 4.2 11.2 2.7 is belong to cluster:0
Vectors 18.0 4.5 3.8 is belong to cluster:1
0.0 0.0 5.0 0
0.1 10.1 0.1 0
1.2 5.2 13.5 0
9.5 9.0 9.0 0
9.1 9.1 9.1 0
19.2 9.4 29.2 0
5.8 3.0 18.0 0
3.5 12.2 60.0 1
3.6 7.9 8.1 0