決策樹是一種非常經典的分類器,它的作用原理有點類似於我們玩的猜謎游戲。比如猜一個動物: 問:這個動物是陸生動物嗎? 答:是的。 問:這個動物有鰓嗎? 答:沒有。 這樣的兩個問題順序就有些顛倒,因為一般來說陸生動物是沒有鰓的(記得應該是這樣的,如有錯誤歡迎指正)。所以玩這種游戲,提問的順序很重要,爭取 ...
決策樹是一種非常經典的分類器,它的作用原理有點類似於我們玩的猜謎游戲。比如猜一個動物:
問:這個動物是陸生動物嗎?
答:是的。
問:這個動物有鰓嗎?
答:沒有。
這樣的兩個問題順序就有些顛倒,因為一般來說陸生動物是沒有鰓的(記得應該是這樣的,如有錯誤歡迎指正)。所以玩這種游戲,提問的順序很重要,爭取每次都能夠獲得儘可能多的信息量。
AllElectronics顧客資料庫標記類的訓練元組 | |||||
RID | age | income | student | credit_rating | Class: buys_computer |
1 | youth | high | no | fair | no |
2 | youth | high | no | excellent | no |
3 | middle_aged | high | no | fair | yes |
4 | senior | medium | no | fair | yes |
5 | senior | low | yes | fair | yes |
6 | senior | low | yes | excellent | no |
7 | middle_aged | low | yes | excellent | yes |
8 | youth | medium | no | fair | no |
9 | youth | low | yes | fair | yes |
10 | senior | medium | yes | fair | yes |
11 | youth | medium | yes | excellent | yes |
12 | middle_aged | medium | no | excellent | yes |
13 | middle_aged | high | yes | fair | yes |
14 | senior | medium | no | excellent | no |
以AllElectronics顧客資料庫標記類的訓練元組為例。我們想要以這些樣本為訓練集,訓練我們的決策樹模型,以此來挖掘出顧客是否會購買電腦的決策模式。
在決策樹ID3演算法中,計算信息度的公式如下:
$$Info_A(D) = \sum_{j=1}^v\frac{|D_j|}{D} \times Info(D_j)$$
計算信息增益的公式如下:
$$Gain(A) = Info(D) - Info_A(D)$$
按照公式,在要進行分類的類別變數中,有5個“no”和9個“yes”,因此期望信息為:
$$Info(D)=-\frac{9}{14}log_2\frac{9}{14}-\frac{5}{14}log_2\frac{5}{14}=0.940$$
首先計算特征age的期望信息:
$$Info_{age}(D)=\frac{5}{14} \times (-\frac{2}{5}log_2\frac{2}{5} - \frac{3}{5}log_2\frac{3}{5})+\frac{4}{14} \times (-\frac{4}{4}log_2\frac{4}{4} - \frac{0}{4}log_2\frac{0}{4})+\frac{5}{14} \times (-\frac{3}{5}log_2\frac{3}{5} - \frac{2}{5}log_2\frac{2}{5})$$
因此,如果按照age進行劃分,則獲得的信息增益為:
$$Gain(age) = Info(D)-Info_{age}(D) = 0.940-0.694=0.246$$
依次計算以income、student和credit_rating來分裂的信息增益,由此選擇能夠帶來最大信息增益的變數,在當
前結點選擇以以該變數的取值進行分裂。遞歸地進行執行即可生成決策樹。更加詳細的內容可以參考:
https://en.wikipedia.org/wiki/Decision_tree
C#代碼的實現如下:
1 using System; 2 using System.Collections.Generic; 3 using System.Linq; 4 namespace MachineLearning.DecisionTree 5 { 6 public class DecisionTreeID3<T> where T : IEquatable<T> 7 { 8 T[,] Data; 9 string[] Names; 10 int Category; 11 T[] CategoryLabels; 12 DecisionTreeNode<T> Root; 13 public DecisionTreeID3(T[,] data, string[] names, T[] categoryLabels) 14 { 15 Data = data; 16 Names = names; 17 Category = data.GetLength(1) - 1;//類別變數需要放在最後一列 18 CategoryLabels = categoryLabels; 19 } 20 public void Learn() 21 { 22 int nRows = Data.GetLength(0); 23 int nCols = Data.GetLength(1); 24 int[] rows = new int[nRows]; 25 int[] cols = new int[nCols]; 26 for (int i = 0; i < nRows; i++) rows[i] = i; 27 for (int i = 0; i < nCols; i++) cols[i] = i; 28 Root = new DecisionTreeNode<T>(-1, default(T)); 29 Learn(rows, cols, Root); 30 DisplayNode(Root); 31 } 32 public void DisplayNode(DecisionTreeNode<T> Node, int depth = 0) 33 { 34 if (Node.Label != -1) 35 Console.WriteLine("{0} {1}: {2}", new string('-', depth * 3), Names[Node.Label], Node.Value); 36 foreach (var item in Node.Children) 37 DisplayNode(item, depth + 1); 38 } 39 private void Learn(int[] pnRows, int[] pnCols, DecisionTreeNode<T> Root) 40 { 41 var categoryValues = GetAttribute(Data, Category, pnRows); 42 var categoryCount = categoryValues.Distinct().Count(); 43 if (categoryCount == 1) 44 { 45 var node = new DecisionTreeNode<T>(Category, categoryValues.First()); 46 Root.Children.Add(node); 47 } 48 else 49 { 50 if (pnRows.Length == 0) return; 51 else if (pnCols.Length == 1) 52 { 53 //投票~ 54 //多數票表決制 55 var Vote = categoryValues.GroupBy(i => i).OrderBy(i => i.Count()).First(); 56 var node = new DecisionTreeNode<T>(Category, Vote.First()); 57 Root.Children.Add(node); 58 } 59 else 60 { 61 var maxCol = MaxEntropy(pnRows, pnCols); 62 var attributes = GetAttribute(Data, maxCol, pnRows).Distinct(); 63 string currentPrefix = Names[maxCol]; 64 foreach (var attr in attributes) 65 { 66 int[] rows = pnRows.Where(irow => Data[irow, maxCol].Equals(attr)).ToArray(); 67 int[] cols = pnCols.Where(i => i != maxCol).ToArray(); 68 var node = new DecisionTreeNode<T>(maxCol, attr); 69 Root.Children.Add(node); 70 Learn(rows, cols, node);//遞歸生成決策樹 71 } 72 } 73 } 74 } 75 public double AttributeInfo(int attrCol, int[] pnRows) 76 { 77 var tuples = AttributeCount(attrCol, pnRows); 78 var sum = (double)pnRows.Length; 79 double Entropy = 0.0; 80 foreach (var tuple in tuples) 81 { 82 int[] count = new int[CategoryLabels.Length]; 83 foreach (var irow in pnRows) 84 if (Data[irow, attrCol].Equals(tuple.Item1)) 85 { 86 int index = Array.IndexOf(CategoryLabels, Data[irow, Category]); 87 count[index]++; 88 } 89 double k = 0.0; 90 for (int i = 0; i < count.Length; i++) 91 { 92 double frequency = count[i] / (double)tuple.Item2; 93 double t = -frequency * Log2(frequency); 94 k += t; 95 } 96 double freq = tuple.Item2 / sum; 97 Entropy += freq * k; 98 } 99 return Entropy; 100 } 101 public double CategoryInfo(int[] pnRows) 102 { 103 var tuples = AttributeCount(Category, pnRows); 104 var sum = (double)pnRows.Length; 105 double Entropy = 0.0; 106 foreach (var tuple in tuples) 107 { 108 double frequency = tuple.Item2 / sum; 109 double t = -frequency * Log2(frequency); 110 Entropy += t; 111 } 112 return Entropy; 113 } 114 private static IEnumerable<T> GetAttribute(T[,] data, int col, int[] pnRows) 115 { 116 foreach (var irow in pnRows) 117 yield return data[irow, col]; 118 } 119 private static double Log2(double x) 120 { 121 return x == 0.0 ? 0.0 : Math.Log(x, 2.0); 122 } 123 public int MaxEntropy(int[] pnRows, int[] pnCols) 124 { 125 double cateEntropy = CategoryInfo(pnRows); 126 int maxAttr = 0; 127 double max = double.MinValue; 128 foreach (var icol in pnCols) 129 if (icol != Category) 130 { 131 double Gain = cateEntropy - AttributeInfo(icol, pnRows); 132 if (max < Gain) 133 { 134 max = Gain; 135 maxAttr = icol; 136 } 137 } 138 return maxAttr; 139 } 140 public IEnumerable<Tuple<T, int>> AttributeCount(int col, int[] pnRows) 141 { 142 var tuples = from n in GetAttribute(Data, col, pnRows) 143 group n by n into i 144 select Tuple.Create(i.First(), i.Count()); 145 return tuples; 146 } 147 } 148 }
調用方法如下:
1 using System; 2 using System.Collections.Generic; 3 using System.Linq; 4 using System.Text; 5 using System.Threading.Tasks; 6 using MachineLearning.DecisionTree; 7 namespace MachineLearning 8 { 9 class Program 10 { 11 static void Main(string[] args) 12 { 13 var data = new string[,] 14 { 15 {"youth","high","no","fair","no"}, 16 {"youth","high","no","excellent","no"}, 17 {"middle_aged","high","no","fair","yes"}, 18 {"senior","medium","no","fair","yes"}, 19 {"senior","low","yes","fair","yes"}, 20 {"senior","low","yes","excellent","no"}, 21 {"middle_aged","low","yes","excellent","yes"}, 22 {"youth","medium","no","fair","no"}, 23 {"youth","low","yes","fair","yes"}, 24 {"senior","medium","yes","fair","yes"}, 25 {"youth","medium","yes","excellent","yes"}, 26 {"middle_aged","medium","no","excellent","yes"}, 27 {"middle_aged","high","yes","fair","yes"}, 28 {"senior","medium","no","excellent","no"} 29 }; 30 var names = new string[] { "age", "income", "student", "credit_rating", "Class: buys_computer" }; 31 var tree = new DecisionTreeID3<string>(data, names, new string[] { "yes", "no" }); 32 tree.Learn(); 33 Console.ReadKey(); 34 } 35 } 36 }
運行結果:
註:作者本人也在學習中,能力有限,如有錯漏還請不吝指正。轉載請註明作者。