決策樹ID3演算法的java實現(基本試用所有的ID3)

来源:http://www.cnblogs.com/tk55/archive/2016/12/28/6231206.html
-Advertisement-
Play Games

已知:流感訓練數據集,預定義兩個類別; 求:用ID3演算法建立流感的屬性描述決策樹 流感訓練數據集 No. 頭痛 肌肉痛 體溫 患流感 1 是(1) 是(1) 正常(0) 否(0) 2 是(1) 是(1) 高(1) 是(1) 3 是(1) 是(1) 很高(2) 是(1) 4 否(0) 是(1) 正常( ...


已知:流感訓練數據集,預定義兩個類別;

求:用ID3演算法建立流感的屬性描述決策樹

流感訓練數據集

No.

頭痛

肌肉痛

體溫

患流感

1

是(1)

是(1)

正常(0)

否(0)

2

是(1)

是(1)

高(1)

是(1)

3

是(1)

是(1)

很高(2)

是(1)

4

否(0)

是(1)

正常(0)

否(0)

5

否(0)

否(0)

高(1)

否(0)

6

否(0)

是(1)

很高(2)

是(1)

7

是(1)

否(0)

高(1)

是(1)

                                                                

 

 

原理分析:

 

在決策樹的每一個非葉子結點劃分之前,先計算每一個屬性所帶來的信息增益,選擇最大信息增益的屬性來劃分,因為信息增益越大,區分樣本的能力就越強,越具有代表性其中。

 

信息熵計算:

 

 

信息增益:

 

 

計算的結果(草稿上的字醜別噴):

 

--------------------------------------------------------------------------------------------------------------------------------------------

 

 

*************************************************************************************************************

************************實現*********************************************

 

 

package ID3Tree;
import java.util.Comparator;;
@SuppressWarnings("rawtypes")
public class Comparisons implements Comparator 
{
    public int compare(Object a, Object b) throws ClassCastException{
        String str1 = (String)a;
        String str2 = (String)b;
        return str1.compareTo(str2);
    }
}

 

package ID3Tree;

public class Entropy {
    //信息熵
    public static double getEntropy(int x, int total)
    {
        if (x == 0)
        {
            return 0;
        }
        double x_pi = getShang(x,total);
        return -(x_pi*Logs(x_pi));
    }

    public static double Logs(double y)
    {
        return Math.log(y) / Math.log(2);
    }
    
    
    public static double getShang(int x, int total)
    {
        return x * Double.parseDouble("1.0") / total;
    }
}

 

package ID3Tree;

public class TreeNode {
    //父節點
    TreeNode parent;
    //指向父節點的屬性
    String parentAttribute;
    String nodeName;
    String[] attributes;
    TreeNode[] childNodes;
}

 

package ID3Tree;
import java.util.*;

public class UtilID3 {
    TreeNode root;
    private boolean[] flag;
    //訓練集
    private Object[] trainArrays;
    //節點索引
    private int nodeIndex;
    public static void main(String[] args)
    {
        //初始化訓練集數組
        Object[] arrays = new Object[]{
                new String[]{"是","是","正常","否"},
                new String[]{"是","是","高","是"},
                new String[]{"是","是","很高","是"},
                new String[]{"否","是","正常","否"},
                new String[]{"否","否","高","否"},
                new String[]{"否","是","很高","是"},
                new String[]{"是","否","高","是"}};
        UtilID3 ID3Tree = new UtilID3();
        ID3Tree.create(arrays, 3);
    }

    //創建
    public void create(Object[] arrays, int index)
    {
        this.trainArrays = arrays;
        initial(arrays, index);
        createDTree(arrays);
        printDTree(root);
    }
    
    //初始化
    public void initial(Object[] dataArray, int index)
    {
        this.nodeIndex = index;
        
        //數據初始化
        this.flag = new boolean[((String[])dataArray[0]).length];
        for (int i = 0; i<this.flag.length; i++)
        {
            if (i == index)
            {
                this.flag[i] = true;
            }
            else
            {
                this.flag[i] = false;
            }
        }
    }
    
    //創建決策樹
    public void createDTree(Object[] arrays)
    {
        Object[] ob = getMaxGain(arrays);
        if (this.root == null)
        {
            this.root = new TreeNode();
            root.parent = null;
            root.parentAttribute = null;
            root.attributes = getAttributes(((Integer)ob[1]).intValue());
            root.nodeName = getNodeName(((Integer)ob[1]).intValue());
            root.childNodes = new TreeNode[root.attributes.length];
            insert(arrays, root);
        }
    }
    
    //插入決策樹
    public void insert(Object[] arrays, TreeNode parentNode)
    {
        String[] attributes = parentNode.attributes;
        for (int i = 0; i < attributes.length; i++)
        {
            Object[] Arrays = pickUpAndCreateArray(arrays, attributes[i],getNodeIndex(parentNode.nodeName));
            Object[] info = getMaxGain(Arrays);
            double gain = ((Double)info[0]).doubleValue();
            if (gain != 0)
            {
                int index = ((Integer)info[1]).intValue();
                TreeNode currentNode = new TreeNode();
                currentNode.parent = parentNode;
                currentNode.parentAttribute = attributes[i];
                currentNode.attributes = getAttributes(index);
                currentNode.nodeName = getNodeName(index);
                currentNode.childNodes = new TreeNode[currentNode.attributes.length];
                parentNode.childNodes[i] = currentNode;
                insert(Arrays, currentNode);
            }
            else
            {
                TreeNode leafNode = new TreeNode();
                leafNode.parent = parentNode;
                leafNode.parentAttribute = attributes[i];
                leafNode.attributes = new String[0];
                leafNode.nodeName = getLeafNodeName(Arrays);
                leafNode.childNodes = new TreeNode[0];
                parentNode.childNodes[i] = leafNode;
            }
        }
    }
    
    //輸出
    public void printDTree(TreeNode node)
    {
        System.out.println(node.nodeName);
        TreeNode[] childs = node.childNodes;
        for (int i = 0; i < childs.length; i++)
        {
            if (childs[i] != null)
            {
                System.out.println("如果:"+childs[i].parentAttribute);
                printDTree(childs[i]);
            }
        }
    }
    
    //剪取數組
    public Object[] pickUpAndCreateArray(Object[] arrays, String attribute, int index)
    {
        List<String[]> list = new ArrayList<String[]>();
        for (int i = 0; i < arrays.length; i++)
        {
            String[] strs = (String[])arrays[i];
            if (strs[index].equals(attribute))
            {
                list.add(strs);
            }
        }
        return list.toArray();
    }
    
    //取得節點名
    public String getNodeName(int index)
    {
        String[] strs = new String[]{"頭痛","肌肉痛","體溫","患流感"};
        for (int i = 0; i < strs.length; i++)
        {
            if (i == index)
            {
                return strs[i];
            }
        }
        return null;
    }
    
    //取得葉子節點名
    public String getLeafNodeName(Object[] arrays)
    {
        if (arrays != null && arrays.length > 0)
        {
            String[] strs = (String[])arrays[0];
            return strs[nodeIndex];
        }
        return null;
    }
    
    //取得節點索引
    public int getNodeIndex(String name)
    {
        String[] strs = new String[]{"頭痛","肌肉痛","體溫","患流感"};
        for (int i = 0; i < strs.length; i++)
        {
            if (name.equals(strs[i]))
            {
                return i;
            }
        }
        return -1;
    }
    
    
    
    //得到最大信息增益
    public Object[] getMaxGain(Object[] arrays)
    {
        Object[] result = new Object[2];
        double gain = 0;
        int index = -1;
        for (int i = 0; i<this.flag.length; i++)
        {
            if (!this.flag[i])
            {
                double value = gain(arrays, i);
                if (gain < value)
                {
                    gain = value;
                    index = i;
                }
            }
        }
        result[0] = gain;
        result[1] = index;
        if (index != -1)
        {
            this.flag[index] = true;
        }
        return result;
    }
    
    //取得屬性數組
    public String[] getAttributes(int index)
    {
        @SuppressWarnings("unchecked")
        TreeSet<String> set = new TreeSet<String>(new Comparisons());
        for (int i = 0; i<this.trainArrays.length; i++)
        {
            String[] strs = (String[])this.trainArrays[i];
            set.add(strs[index]);
        }
        String[] result = new String[set.size()];
        return set.toArray(result);
        
    }
    
    //計算信息增益
    public double gain(Object[] arrays, int index)
    {
        String[] playBalls = getAttributes(this.nodeIndex);
        int[] counts = new int[playBalls.length];
        for (int i = 0; i<counts.length; i++)
        {
            counts[i] = 0;
        }
        
        for (int i = 0; i<arrays.length; i++)
        {
            String[] strs = (String[])arrays[i];
            for (int j = 0; j<playBalls.length; j++)
            {    
                if (strs[this.nodeIndex].equals(playBalls[j]))
                {
                    counts[j]++;
                }
            }
        }
        
        double entropyS = 0;
        for (int i = 0;i <counts.length; i++)
        {
            entropyS = entropyS + Entropy.getEntropy(counts[i], arrays.length);
        }
        
        String[] attributes = getAttributes(index);
        double total = 0;
        for (int i = 0; i<attributes.length; i++)
        {
            total = total + entropy(arrays, index, attributes[i], arrays.length);
        }
        return entropyS - total;
    }
    
    
    public double entropy(Object[] arrays, int index, String attribute, int totals)
    {
        String[] playBalls = getAttributes(this.nodeIndex);
        int[] counts = new int[playBalls.length];
        for (int i = 0; i < counts.length; i++)
        {
            counts[i] = 0;
        }
        
        for (int i = 0; i < arrays.length; i++)
        {
            String[] strs = (String[])arrays[i];
            if (strs[index].equals(attribute))
            {
                for (int k = 0; k<playBalls.length; k++)
                {
                    if (strs[this.nodeIndex].equals(playBalls[k]))
                    {
                        counts[k]++;
                    }
                }
            }
        }
        
        int total = 0;
        double entropy = 0;
        for (int i = 0; i < counts.length; i++)
        {
            total = total +counts[i];
        }
        
        for (int i = 0; i < counts.length; i++)
        {
            entropy = entropy + Entropy.getEntropy(counts[i], total);
        }
        return Entropy.getShang(total, totals)*entropy;
    }
}

 

 


您的分享是我們最大的動力!

-Advertisement-
Play Games
更多相關文章
  • VS項目中使用Nuget還原包後編譯生產還一直報錯? ...
  • 申請博客 ...
  • 在上一篇C#多線程之線程池篇1中,我們主要學習瞭如何線上程池中調用委托以及如何線上程池中執行非同步操作,在這篇中,我們將學習線程池和並行度、實現取消選項的相關知識。 三、線程池和並行度 在這一小節中,我們將學習對於大量的非同步操作,使用線程池和分別使用單獨的線程在性能上有什麼差異性。具體操作步驟如下: ...
  • 安裝了Visual Studio的那些使用微軟平臺的開發者通常能夠非常容易地操作自己的項目:打開解決方案,修改內容,設置好所有必須的文件以及配置後編譯項目。但是在構建伺服器或者持續交付系統等沒有安裝Visual Studio的環境中,編譯項目和解決方案是非常難的。 對於這一問題,微軟之前給出的方案是 ...
  • 高級FTP伺服器開發 一,作業要求 高級FTP伺服器開發 二,程式文件清單 Folder目錄:用戶文件目錄 bin目錄:程式啟動文件目錄 conf目錄:用戶配置文件目錄 core目錄:程式核心代碼目錄 log目錄:程式日誌文件目錄 三,程式流程簡圖 四,程式測試樣圖 創建賬戶 用戶登錄 基本操作 五 ...
  • 好的介面容易被正確使用,不易被誤用 考慮以下函數: void func(int year,int month,int day){ //一些操作 } 這個函數看似合理,因為參數的名字已經暴露了它的用途。但是如果只有寒暑簽名呢?如下: void func(int,int,int); 就算我告訴你,此處需... ...
  • 定義和用法 unserialize() 將已序列化的字元串還原回 PHP 的值。 序列化請使用 serialize() 函數。 定義和用法 unserialize() 將已序列化的字元串還原回 PHP 的值。 序列化請使用 serialize() 函數。 語法 unserialize(str) 參數 ...
  • 《大話數據結構》中這樣介紹冒泡排序的基本思想:兩兩比較相鄰元素的關鍵字,如果反序則交換,直到沒有反序的記錄為止。也就是相鄰元素之間兩兩比較,如果前一個值大於後一個(或者前一個值小於後一個),則交換順序,所以這樣的話,最終的結果是從小到大(或者從大到小)的排序。 當然php有非常強大的排序函數,比如s ...
一周排行
    -Advertisement-
    Play Games
  • 移動開發(一):使用.NET MAUI開發第一個安卓APP 對於工作多年的C#程式員來說,近來想嘗試開發一款安卓APP,考慮了很久最終選擇使用.NET MAUI這個微軟官方的框架來嘗試體驗開發安卓APP,畢竟是使用Visual Studio開發工具,使用起來也比較的順手,結合微軟官方的教程進行了安卓 ...
  • 前言 QuestPDF 是一個開源 .NET 庫,用於生成 PDF 文檔。使用了C# Fluent API方式可簡化開發、減少錯誤並提高工作效率。利用它可以輕鬆生成 PDF 報告、發票、導出文件等。 項目介紹 QuestPDF 是一個革命性的開源 .NET 庫,它徹底改變了我們生成 PDF 文檔的方 ...
  • 項目地址 項目後端地址: https://github.com/ZyPLJ/ZYTteeHole 項目前端頁面地址: ZyPLJ/TreeHoleVue (github.com) https://github.com/ZyPLJ/TreeHoleVue 目前項目測試訪問地址: http://tree ...
  • 話不多說,直接開乾 一.下載 1.官方鏈接下載: https://www.microsoft.com/zh-cn/sql-server/sql-server-downloads 2.在下載目錄中找到下麵這個小的安裝包 SQL2022-SSEI-Dev.exe,運行開始下載SQL server; 二. ...
  • 前言 隨著物聯網(IoT)技術的迅猛發展,MQTT(消息隊列遙測傳輸)協議憑藉其輕量級和高效性,已成為眾多物聯網應用的首選通信標準。 MQTTnet 作為一個高性能的 .NET 開源庫,為 .NET 平臺上的 MQTT 客戶端與伺服器開發提供了強大的支持。 本文將全面介紹 MQTTnet 的核心功能 ...
  • Serilog支持多種接收器用於日誌存儲,增強器用於添加屬性,LogContext管理動態屬性,支持多種輸出格式包括純文本、JSON及ExpressionTemplate。還提供了自定義格式化選項,適用於不同需求。 ...
  • 目錄簡介獲取 HTML 文檔解析 HTML 文檔測試參考文章 簡介 動態內容網站使用 JavaScript 腳本動態檢索和渲染數據,爬取信息時需要模擬瀏覽器行為,否則獲取到的源碼基本是空的。 本文使用的爬取步驟如下: 使用 Selenium 獲取渲染後的 HTML 文檔 使用 HtmlAgility ...
  • 1.前言 什麼是熱更新 游戲或者軟體更新時,無需重新下載客戶端進行安裝,而是在應用程式啟動的情況下,在內部進行資源或者代碼更新 Unity目前常用熱更新解決方案 HybridCLR,Xlua,ILRuntime等 Unity目前常用資源管理解決方案 AssetBundles,Addressable, ...
  • 本文章主要是在C# ASP.NET Core Web API框架實現向手機發送驗證碼簡訊功能。這裡我選擇是一個互億無線簡訊驗證碼平臺,其實像阿裡雲,騰訊雲上面也可以。 首先我們先去 互億無線 https://www.ihuyi.com/api/sms.html 去註冊一個賬號 註冊完成賬號後,它會送 ...
  • 通過以下方式可以高效,並保證數據同步的可靠性 1.API設計 使用RESTful設計,確保API端點明確,並使用適當的HTTP方法(如POST用於創建,PUT用於更新)。 設計清晰的請求和響應模型,以確保客戶端能夠理解預期格式。 2.數據驗證 在伺服器端進行嚴格的數據驗證,確保接收到的數據符合預期格 ...