吴裕雄--天生自然python机器学习:决策树算法
我们经常使用决策树处理分类问题’近来的调查表明决策树也是最经常使用的数据挖掘算法。
它之所以如此流行,一个很重要的原因就是使用者基本上不用了解机器学习算法,也不用深究它
是如何工作的。
K-近邻算法可以完成很多分类任务,但是它最大的缺点就是无法给出数据的内
在含义,决策树的主要优势就在于数据形式非常容易理解。
决策树很多任务都
是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取出一系列
规则,机器学习算法最终将使用这些机器从数据集中创造的规则。专家系统中经常使用决策树,
而且决策树给出结果往往可以匹敌在当前领域具有几十年工作经验的人类专家。
决策树的构造
首先我们讨论数学上如何使用
信息论划分数据集,然后编写代码将理论应用到具体的数据集上,最后编写代码构建决策树。
在构造决策树时,我们需要解决的第一个问题就是,当前数据集上哪个特征在划分数据分类
时起决定性作用。为了找到决定性的特征,划分出最好的结果,我们必须评估每个特征。完成测
试之后,原始数据集就被划分为几个数据子集。这些数据子集会分布在第一个决策点的所有分支
上。如果某个分支下的数据属于同一类型,则当前无需阅读的垃圾邮件已经正确地划分数据分类,
无需进一步对数据集进行分割。如果数据子集内的数据不属于同一类型,则需要重复划分数据子
集的过程。如何划分数据子集的算法和划分原始数据集的方法相同,直到所有具有相同类型的数
据均在一个数据子集内。
信息増益
划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,
但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息
论是量化处理信息的分支科学。我们可以在划分数据之前使用信息论量化度量信息的内容。
在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以
计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
在可以评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集
合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德•香农。
计算给定数据集的香农熵
def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCounts = {} for featVec in dataSet: #the the number of unique elements and their occurance currentLabel = featVec[-1] if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 shannonEnt = 0.0 for key in labelCounts: prob = float(labelCounts[key])/numEntries shannonEnt -= prob * log(prob,2) #log base 2 return shannonEnt
import operator from math import log def createDataSet(): dataSet = [[1, 1, ‘yes‘], [1, 1, ‘yes‘], [1, 0, ‘no‘], [0, 1, ‘no‘], [0, 1, ‘no‘]] labels = [‘no surfacing‘,‘flippers‘] #change to discrete values return dataSet, labels dataSet, labels = createDataSet() print(dataSet) print(labels)
shannonEnt = calcShannonEnt(dataSet) print(shannonEnt)
熵越高,则混合的数据也越多,我们可以在数据集中添加更多的分类,观察熵是如何变化的。
得到熵之后,我们就可以按照获取最大信息增益的方法划分数据集.
分类算法除了需要测量信息熵,还需要划分数
据集,度量花费数据集的熵,以便判断当前是否正确地划分了数据集。我们将对每个特征划分数
据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。
def splitDataSet(dataSet, axis, value): retDataSet = [] for featVec in dataSet: if featVec[axis] == value: reducedFeatVec = featVec[:axis] #chop out axis used for splitting reducedFeatVec.extend(featVec[axis+1:]) retDataSet.append(reducedFeatVec) return retDataSet
retDataSet = splitDataSet(dataSet,0,0) print(retDataSet)
retDataSet = splitDataSet(dataSet,0,0) print(retDataSet)
选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet): numFeatures = len(dataSet[0]) - 1 #the last column is used for the labels baseEntropy = calcShannonEnt(dataSet) bestInfoGain = 0.0; bestFeature = -1 for i in range(numFeatures): #iterate over all the features featList = [example[i] for example in dataSet]#create a list of all the examples of this feature uniqueVals = set(featList) #get a set of unique values newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet, i, value) prob = len(subDataSet)/float(len(dataSet)) newEntropy += prob * calcShannonEnt(subDataSet) infoGain = baseEntropy - newEntropy #calculate the info gain; ie reduction in entropy if (infoGain > bestInfoGain): #compare this to the best gain so far bestInfoGain = infoGain #if better than current best, set to best bestFeature = i return bestFeature
bestFeature = chooseBestFeatureToSplit(dataSet) print(bestFeature)
递归构建决策树
从数据集构造决策树算法所需要的子功能模块,其工作原理如下:得到
原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在大于
两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节
点上,我们可以再次划分数据。因此我们可以采用递归的原则处理数据集。
递归结束的条件是:程序遍历完所有划分数据集的属性,或者每个分支下的所有实例都具有
相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节
点的数据必然属于叶子节点的分类.
如果数据集已经处理了所有属性,但是类标签依然不是唯一
的,此时我们需要决定如何定义该叶子节点,在这种情况下,我们通常会采用多数表决的方法决
定该叶子节点的分类。
def majorityCnt(classList): classCount={} for vote in classList: if vote not in classCount.keys(): classCount[vote] = 0 classCount[vote] += 1 sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) return sortedClassCount[0][0]
创建树的函数代码
def createTree(dataSet,labels): classList = [example[-1] for example in dataSet] if(classList.count(classList[0]) == len(classList)): return classList[0] if len(dataSet[0]) == 1: return majorityCnt(classList) bestFeat = chooseBestFeatureToSplit(dataSet) bestFeatLabel = labels[bestFeat] myTree = {bestFeatLabel:{}} del(labels[bestFeat]) featValues = [example[bestFeat] for example in dataSet] uniqueVals = set(featValues) for value in uniqueVals: subLabels = labels[:] myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value),subLabels) return myTree
myTree = createTree(dataSet, labels) print(myTree)
在 Python中使用Matplotlib注解绘制树形图
决策树的主要优点就是直观
易于理解,如果不能将其直观地显示出来,就无法发挥其优势。虽然前面章节我们使用的图形库
已经非常强大,但是Pyth0n并没有提供绘制树的工具,因此我们必须自己绘制树形图。
Matplotlib 注解
Matplotlib提供了一个注解工具annotations非常有用,它可以在数据图形上添加文本注
释。注解通常用于解释数据的内容。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支
持带箭头的划线工具,使得我们可以在其他恰当的地方指向数据位置,并在此处添加描述信息,
解释数据内容。
import matplotlib.pyplot as plt decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") def getNumLeafs(myTree): numLeafs = 0 for i in myTree.keys(): firstStr = i break secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘: numLeafs += getNumLeafs(secondDict[key]) else: numLeafs +=1 return numLeafs def getTreeDepth(myTree): maxDepth = 0 for i in myTree.keys(): firstStr = i break secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘: thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=‘axes fraction‘,xytext=centerPt, textcoords=‘axes fraction‘,va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) def plotTree(myTree, parentPt, nodeTxt): numLeafs = getNumLeafs(myTree) depth = getTreeDepth(myTree) for i in myTree.keys(): firstStr = i break cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘: plotTree(secondDict[key],cntrPt,str(key)) else: plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD def createPlot(inTree): fig = plt.figure(1, facecolor=‘white‘) fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), ‘‘) plt.show() def retrieveTree(i): listOfTrees =[{‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: ‘no‘, 1: ‘yes‘}}}}, {‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: {‘head‘: {0: ‘no‘, 1: ‘yes‘}}, 1: ‘no‘}}}} ] return listOfTrees[i] thisTree = retrieveTree(0) createPlot(thisTree) thisTree = retrieveTree(1) createPlot(thisTree)
测试算法:使用决策树执行分类
依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,
需要决策树以及用于构造树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行
该过程直到进人叶子节点;最后将测试数据定义为叶子节点所属的类型。
def classify(inputTree,featLabels,testVec): for i in inputTree.keys(): firstStr = i break secondDict = inputTree[firstStr] featIndex = featLabels.index(firstStr) key = testVec[featIndex] valueOfFeat = secondDict[key] if isinstance(valueOfFeat, dict): classLabel = classify(valueOfFeat, featLabels, testVec) else: classLabel = valueOfFeat return classLabel
featLabels = [‘no surfacing‘, ‘flippers‘] classLabel = classify(myTree,featLabels,[1,1]) print(classLabel)
使 用 算 法 :决策树的存储
构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的
时间,如果数据集很大,将会耗费很多计算时间。然而用创建好的决策树解决分类问题,贝何以
很快完成。因此,为了节省计算时间,最好能够在每次执行分类时调用巳经构造好的决策树。
使用pickle块存储决策树
import pickle def storeTree(inputTree,filename): fw = open(filename,‘wb‘) pickle.dump(inputTree,fw) fw.close() def grabTree(filename): fr = open(filename,‘rb‘) return pickle.load(fr) filename = "E:\\mytree.txt" storeTree(myTree,filename) mySecTree = grabTree(filename) print(mySecTree)
featLabels = [‘no surfacing‘, ‘flippers‘] classLabel = classify(mySecTree,featLabels,[0,0]) print(classLabel)