決策樹(Decision Tree)的核心思想是:根據訓練樣本構建這樣一棵樹,使得其葉節點是分類標簽,非葉節點是判斷條件,這樣對于一個未知樣本,能在樹上找到一條路徑到達葉節點,就得到了它的分類。
舉個簡單的例子,如何識別有毒的蘑菇?如果能夠得到一棵這樣的決策樹,那么對于一個未知的蘑菇就很容易判斷出它是否有毒了。
它是什么顏色的? | -------鮮艷---------淺色---- | | 有毒 有什么氣味? | -----刺激性--------無味----- | | 有毒 安全
構建決策樹有很多算法,常用的有ID3、C4.5等。本篇以ID3為研究算法。
構建決策樹的關鍵在于每一次分支時選擇哪個特征作為分界條件。這里的原則是:選擇最能把數據變得有序的特征作為分界條件。所謂有序,是指劃分后,每一個分支集合的分類盡可能一致。用信息論的方式表述,就是選擇信息增益最大的方式劃分集合。
所謂信息增益(information gain),是指變化前后熵(entropy)的增加量。為了計算熵,需要計算所有類別所有可能值包含的信息期望值,通過下面的公式得到:

其中H為熵,n為分類數目,p(xi)是選擇該分類的概率。
根據公式,計算一個集合熵的方式為:
計算每個分類出現的次數
foreach(每一個分類)
{
計算出現概率
根據概率計算熵
累加熵
}
return 累加結果判斷如何劃分集合,方式為:
foreach(每一個特征)
{
計算按此特征切分時的熵
計算與切分前相比的信息增益
保留能產生最大增益的特征為切分方式
}
return 選定的特征構建樹節點的方法為:
if(集合沒有特征可用了)
{
按多數原則決定此節點的分類
}
else if(集合中所有樣本的分類都一致)
{
此標簽就是節點分類
}
else
{
以最佳方式切分集合
每一種可能形成當前節點的一個分支
遞歸
}OK,上C#版代碼,DataVector和上篇文章一樣,不放了,只放核心算法:
using System;
using System.Collections.Generic;
namespace MachineLearning
{
/// <summary>
/// 決策樹節點
/// </summary>
public class DecisionNode
{
/// <summary>
/// 此節點的分類標簽,為空表示此節點不是葉節點
/// </summary>
public string Label { get; set; }
/// <summary>
/// 此節點的劃分特征,為-1表示此節點是葉節點
/// </summary>
public int FeatureIndex { get; set; }
/// <summary>
/// 分支
/// </summary>
public Dictionary<string, DecisionNode> Child { get; set; }
public DecisionNode()
{
this.FeatureIndex = -1;
this.Child = new Dictionary<string, DecisionNode>();
}
}
}using System;
using System.Collections.Generic;
using System.Linq;
namespace MachineLearning
{
/// <summary>
/// 決策樹(ID3算法)
/// </summary>
public class DecisionTree
{
private DecisionNode m_Tree;
/// <summary>
/// 訓練
/// </summary>
/// <param name="trainingSet"></param>
public void Train(List<DataVector<string>> trainingSet)
{
var features = new List<int>(trainingSet[0].Dimension);
for(int i = 0;i < trainingSet[0].Dimension;++i)
features.Add(i);
//生成決策樹
m_Tree = CreateTree(trainingSet, features);
}
/// <summary>
/// 分類
/// </summary>
/// <param name="vector"></param>
/// <returns></returns>
public string Classify(DataVector<string> vector)
{
return Classify(vector, m_Tree);
}
/// <summary>
/// 分類
/// </summary>
/// <param name="vector"></param>
/// <param name="node"></param>
/// <returns></returns>
private string Classify(DataVector<string> vector, DecisionNode node)
{
var label = string.Empty;
if(!string.IsNullOrEmpty(node.Label))
{
//是葉節點,直接返回結果
label = node.Label;
}
else
{
//取需要分類的字段,繼續深入
var key = vector.Data[node.FeatureIndex];
if(node.Child.ContainsKey(key))
label = Classify(vector, node.Child[key]);
else
label = "[UNKNOWN]";
}
return label;
}
/// <summary>
/// 創建決策樹
/// </summary>
/// <param name="dataSet"></param>
/// <param name="features"></param>
/// <returns></returns>
private DecisionNode CreateTree(List<DataVector<string>> dataSet, List<int> features)
{
var node = new DecisionNode();
if(dataSet[0].Dimension == 0)
{
//所有字段已用完,按多數原則決定Label,結束分類
node.Label = GetMajorLabel(dataSet);
}
else if(dataSet.Count == dataSet.Count(d => string.Equals(d.Label, dataSet[0].Label)))
{
//如果數據集中的Label相同,結束分類
node.Label = dataSet[0].Label;
}
else
{
//挑選一個最佳分類,分割集合,遞歸
int featureIndex = ChooseBestFeature(dataSet);
node.FeatureIndex = features[featureIndex];
var uniqueValues = GetUniqueValues(dataSet, featureIndex);
features.RemoveAt(featureIndex);
foreach(var value in uniqueValues)
{
node.Child[value.ToString()] = CreateTree(SplitDataSet(dataSet, featureIndex, value), new List<int>(features));
}
}
return node;
}
/// <summary>
/// 計算給定集合的香農熵
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private double ComputeShannon(List<DataVector<string>> dataSet)
{
double shannon = 0.0;
var dict = new Dictionary<string, int>();
foreach(var item in dataSet)
{
if(!dict.ContainsKey(item.Label))
dict[item.Label] = 0;
dict[item.Label] += 1;
}
foreach(var label in dict.Keys)
{
double prob = dict[label] * 1.0 / dataSet.Count;
shannon -= prob * Math.Log(prob, 2);
}
return shannon;
}
/// <summary>
/// 用給定的方式切分出數據子集
/// </summary>
/// <param name="dataSet"></param>
/// <param name="splitIndex"></param>
/// <param name="value"></param>
/// <returns></returns>
private List<DataVector<string>> SplitDataSet(List<DataVector<string>> dataSet, int splitIndex, string value)
{
var newDataSet = new List<DataVector<string>>();
foreach(var item in dataSet)
{
//只保留指定維度上符合給定值的項
if(item.Data[splitIndex] == value)
{
var newItem = new DataVector<string>(item.Dimension - 1);
newItem.Label = item.Label;
Array.Copy(item.Data, 0, newItem.Data, 0, splitIndex - 0);
Array.Copy(item.Data, splitIndex + 1, newItem.Data, splitIndex, item.Dimension - splitIndex - 1);
newDataSet.Add(newItem);
}
}
return newDataSet;
}
/// <summary>
/// 在給定的數據集上選擇一個最好的切分方式
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private int ChooseBestFeature(List<DataVector<string>> dataSet)
{
int bestFeature = 0;
double bestInfoGain = 0.0;
double baseShannon = ComputeShannon(dataSet);
//遍歷每一個維度來尋找
for(int i = 0;i < dataSet[0].Dimension;++i)
{
var uniqueValues = GetUniqueValues(dataSet, i);
double newShannon = 0.0;
//遍歷此維度下的每一個可能值,切分數據集并計算熵
foreach(var value in uniqueValues)
{
var subSet = SplitDataSet(dataSet, i, value);
double prob = subSet.Count * 1.0 / dataSet.Count;
newShannon += prob * ComputeShannon(subSet);
}
//計算信息增益,保留最佳切分方式
double infoGain = baseShannon - newShannon;
if(infoGain > bestInfoGain)
{
bestInfoGain = infoGain;
bestFeature = i;
}
}
return bestFeature;
}
/// <summary>
/// 數據去重
/// </summary>
/// <param name="dataSet"></param>
/// <param name="index"></param>
/// <returns></returns>
private List<string> GetUniqueValues(List<DataVector<string>> dataSet, int index)
{
var dict = new Dictionary<string, int>();
foreach(var item in dataSet)
{
dict[item.Data[index]] = 0;
}
return dict.Keys.ToList<string>();
}
/// <summary>
/// 取多數標簽
/// </summary>
/// <param name="dataSet"></param>
/// <returns></returns>
private string GetMajorLabel(List<DataVector<string>> dataSet)
{
var dict = new Dictionary<string, int>();
foreach(var item in dataSet)
{
if(!dict.ContainsKey(item.Label))
dict[item.Label] = 0;
dict[item.Label]++;
}
string label = string.Empty;
int count = -1;
foreach(var key in dict.Keys)
{
if(dict[key] > count)
{
label = key;
count = dict[key];
}
}
return label;
}
}
}拿個例子實際檢驗一下,還是以毒蘑菇的識別為例,從這里找了點數據,http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.data ,它整理了8000多個樣本,每個樣本描述了蘑菇的22個屬性,比如形狀、氣味等等,然后給出了這個蘑菇是否可食用。
比如一行數據:p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
第0個元素p表示poisonous(有毒),其它22個元素分別是蘑菇的屬性,可以參見agaricus-lepiota.names的描述,但實際上根本不用關心具體含義。以此構建樣本并測試錯誤率:
public void TestDecisionTree()
{
var trainingSet = new List<DataVector<string>>(); //訓練數據集
var testSet = new List<DataVector<string>>(); //測試數據集
//讀取數據
var file = new StreamReader("agaricus-lepiota.data", Encoding.Default);
string line = string.Empty;
int count = 0;
while((line = file.ReadLine()) != null)
{
var parts = line.Split(',');
var p = new DataVector<string>(22);
p.Label = parts[0];
for(int i = 0;i < p.Dimension;++i)
p.Data[i] = parts[i + 1];
//前7000作為訓練樣本,其余作為測試樣本
if(++count <= 7000)
trainingSet.Add(p);
else
testSet.Add(p);
}
file.Close();
//檢驗
var dt = new DecisionTree();
dt.Train(trainingSet);
int error = 0;
foreach(var p in testSet)
{
//做猜測分類,并與實際結果比較
var label = dt.Classify(p);
if(label != p.Label)
++error;
}
Console.WriteLine("Error = {0}/{1}, {2}%", error, testSet.Count, (error * 100.0 / testSet.Count));
}使用7000個樣本做訓練,1124個樣本做測試,只有4個猜測出錯,錯誤率僅有0.35%,相當不錯的結果。
生成的決策樹是這樣的:
{
"FeatureIndex": 4, //按第4個特征劃分
"Child": {
"p": {"Label": "p"}, //如果第4個特征是p,則分類為p
"a": {"Label": "e"}, //如果第4個特征是a,則分類是e
"l": {"Label": "e"},
"n": {
"FeatureIndex": 19, //如果第4個特征是n,要繼續按第19個特征劃分
"Child": {
"n": {"Label": "e"},
"k": {"Label": "e"},
"w": {
"FeatureIndex": 21,
"Child": {
"w": {"Label": "e"},
"l": {
"FeatureIndex": 2,
"Child": {
"c": {"Label": "e"},
"n": {"Label": "e"},
"w": {"Label": "p"},
"y": {"Label": "p"}
}
},
"d": {
"FeatureIndex": 1,
"Child": {
"y": {"Label": "p"},
"f": {"Label": "p"},
"s": {"Label": "e"}
}
},
"g": {"Label": "e"},
"p": {"Label": "e"}
}
},
"h": {"Label": "e"},
"r": {"Label": "p"},
"o": {"Label": "e"},
"y": {"Label": "e"},
"b": {"Label": "e"}
}
},
"f": {"Label": "p"},
"c": {"Label": "p"},
"y": {"Label": "p"},
"s": {"Label": "p"},
"m": {"Label": "p"}
}
}可以看到,實際只使用了其中的5個特征,就能做出比較精確的判斷了。
決策樹還有一個很棒的優點就是能告訴我們多個特征中哪個對判別最有用,比如上面的樹,根節點是特征4,參考agaricus-lepiota.names得知這個特征是指氣味(odor),只要有氣味,就可以直接得出結論,如果是無味的(n=none),下一個重要特征是19,即孢子印的顏色(spore-print-color)。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。