Python机器学习(十九)决策树之系列二—C4.5原理与代码实现
ID3算法缺点
它一般会优先选择有较多属性值的Feature,因为属性值多的特征会有相对较大的信息增益,信息增益反映的是,在给定一个条件以后,不确定性减少的程度,
这必然是分得越细的数据集确定性更高,也就是条件熵越小,信息增益越大。为了解决这个问题,C4.5就应运而生,它采用信息增益率来作为选择分支的准则。
C4.5算法原理
信息增益率定义为:
其中,分子为信息增益(信息增益计算可参考上一节ID3的算法原理),分母为属性X的熵。
需要注意的是,增益率准则对可取值数目较少的属性有所偏好。
所以一般这样选取划分属性:选择增益率最高的特征列作为划分属性的依据。
代码实现
与ID3代码实现不同的是:只改变计算香农熵的函数calcShannonEnt,以及选择最优特征索引函数chooseBestFeatureToSplit,具体代码如下:
# -*- coding: utf-8 -*- """ Created on Thu Aug 2 17:09:34 2018 决策树ID3,C4.5的实现 @author: weixw """ from math import log import operator #原始数据 def createDataSet(): dataSet = [[1, 1, ‘yes‘], [1, 1, ‘yes‘], [1, 0, ‘no‘], [0, 1, ‘no‘], [0, 1, ‘no‘]] labels = [‘no surfacing‘,‘flippers‘] return dataSet, labels #多数表决器 #列中相同值数量最多为结果 def majorityCnt(classList): classCounts = {} for value in classList: if(value not in classCounts.keys()): classCounts[value] = 0 classCounts[value] +=1 sortedClassCount = sorted(classCounts.iteritems(),key = operator.itemgetter(1),reverse =True) return sortedClassCount[0][0] #划分数据集 #dataSet:原始数据集 #axis:进行分割的指定列索引 #value:指定列中的值 def splitDataSet(dataSet,axis,value): retDataSet= [] for featDataVal in dataSet: if featDataVal[axis] == value: #下面两行去除某一项指定列的值,很巧妙有没有 reducedFeatVal = featDataVal[:axis] reducedFeatVal.extend(featDataVal[axis+1:]) retDataSet.append(reducedFeatVal) return retDataSet #计算香农熵 #columnIndex = -1表示获取数据集每一项的最后一列的标签值 #其他表示获取特征列 def calcShannonEnt(columnIndex, dataSet): #数据集总项数 numEntries = len(dataSet) #标签计数对象初始化 labelCounts = {} for featDataVal in dataSet: #获取数据集每一项的最后一列的标签值 currentLabel = featDataVal[columnIndex] #如果当前标签不在标签存储对象里,则初始化,然后计数 if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0 labelCounts[currentLabel] += 1 #熵初始化 shannonEnt = 0.0 #遍历标签对象,求概率,计算熵 for key in labelCounts.keys(): prop = labelCounts[key]/float(numEntries) shannonEnt -= prop*log(prop,2) return shannonEnt #通过信息增益,选出最优特征列索引(ID3) def chooseBestFeatureToSplit(dataSet): #计算特征个数,dataSet最后一列是标签属性,不是特征量 numFeatures = len(dataSet[0])-1 #计算初始数据香农熵 baseEntropy = calcShannonEnt(-1, dataSet) #初始化信息增益,最优划分特征列索引 bestInfoGain = 0.0 bestFeatureIndex = -1 for i in range(numFeatures): #获取每一列数据 featList = [example[i] for example in dataSet] #将每一列数据去重 uniqueVals = set(featList) newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet,i,value) #计算条件概率 prob = len(subDataSet)/float(len(dataSet)) #计算条件熵 newEntropy +=prob*calcShannonEnt(-1, subDataSet) #计算信息增益 infoGain = baseEntropy - newEntropy if(infoGain > bestInfoGain): bestInfoGain = infoGain bestFeatureIndex = i return bestFeatureIndex #通过信息增益率,选出最优特征列索引(C4.5) def chooseBestFeatureToSplitOfFurther(dataSet): #计算特征个数,dataSet最后一列是标签属性,不是特征量 numFeatures = len(dataSet[0])-1 #计算初始数据香农熵H(Y) baseEntropy = calcShannonEnt(-1, dataSet) #初始化信息增益,最优划分特征列索引 bestInfoGainRatio = 0.0 bestFeatureIndex = -1 for i in range(numFeatures): #获取每一特征列香农熵H(X) featEntropy = calcShannonEnt(i, dataSet) #获取每一列数据 featList = [example[i] for example in dataSet] #将每一列数据去重 uniqueVals = set(featList) newEntropy = 0.0 for value in uniqueVals: subDataSet = splitDataSet(dataSet,i,value) #计算条件概率 prob = len(subDataSet)/float(len(dataSet)) #计算条件熵 newEntropy +=prob*calcShannonEnt(-1, subDataSet) #计算信息增益 infoGain = baseEntropy - newEntropy #计算信息增益率 infoGainRatio = infoGain/float(featEntropy) if(infoGainRatio > bestInfoGainRatio): bestInfoGainRatio = infoGainRatio bestFeatureIndex = i return bestFeatureIndex #决策树创建 def createTree(dataSet,labels): #获取标签属性,dataSet最后一列,区别于labels标签名称 classList = [example[-1] for example in dataSet] #树极端终止条件判断 #标签属性值全部相同,返回标签属性第一项值 if classList.count(classList[0]) == len(classList): return classList[0] #没有特征,只有标签列(1列) if len(dataSet[0]) == 1: #返回实例数最大的类 return majorityCnt(classList) # #获取最优特征列索引ID3 # bestFeatureIndex = chooseBestFeatureToSplit(dataSet) #获取最优特征列索引C4.5 bestFeatureIndex = chooseBestFeatureToSplitOfFurther(dataSet) #获取最优索引对应的标签名称 bestFeatureLabel = labels[bestFeatureIndex] #创建根节点 myTree = {bestFeatureLabel:{}} #去除最优索引对应的标签名,使labels标签能正确遍历 del(labels[bestFeatureIndex]) #获取最优列 bestFeature = [example[bestFeatureIndex] for example in dataSet] uniquesVals = set(bestFeature) for value in uniquesVals: #子标签名称集合 subLabels = labels[:] #递归 myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet,bestFeatureIndex,value),subLabels) return myTree #获取分类结果 #inputTree:决策树字典 #featLabels:标签列表 #testVec:测试向量 例如:简单实例下某一路径 [1,1] => yes(树干值组合,从根结点到叶子节点) def classify(inputTree,featLabels,testVec): #获取根结点名称,将dict转化为list firstSide = list(inputTree.keys()) #根结点名称String类型 firstStr = firstSide[0] #获取根结点对应的子节点 secondDict = inputTree[firstStr] #获取根结点名称在标签列表中对应的索引 featIndex = featLabels.index(firstStr) #由索引获取向量表中的对应值 key = testVec[featIndex] #获取树干向量后的对象 valueOfFeat = secondDict[key] #判断是子结点还是叶子节点:子结点就回调分类函数,叶子结点就是分类结果 #if type(valueOfFeat).__name__==‘dict‘: 等价 if isinstance(valueOfFeat, dict): if isinstance(valueOfFeat, dict): classLabel = classify(valueOfFeat,featLabels,testVec) else: classLabel = valueOfFeat return classLabel #将决策树分类器存储在磁盘中,filename一般保存为txt格式 def storeTree(inputTree,filename): import pickle fw = open(filename,‘wb+‘) pickle.dump(inputTree,fw) fw.close() #将瓷盘中的对象加载出来,这里的filename就是上面函数中的txt文件 def grabTree(filename): import pickle fr = open(filename,‘rb‘) return pickle.load(fr)
‘‘‘ Created on Oct 14, 2010 @author: Peter Harrington ‘‘‘ import matplotlib.pyplot as plt decisionNode = dict(boxstyle="sawtooth", fc="0.8") leafNode = dict(boxstyle="round4", fc="0.8") arrow_args = dict(arrowstyle="<-") #获取树的叶子节点 def getNumLeafs(myTree): numLeafs = 0 #dict转化为list firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): #判断是否是叶子节点(通过类型判断,子类不存在,则类型为str;子类存在,则为dict) if type(secondDict[key]).__name__==‘dict‘:#test to see if the nodes are dictonaires, if not they are leaf nodes numLeafs += getNumLeafs(secondDict[key]) else: numLeafs +=1 return numLeafs #获取树的层数 def getTreeDepth(myTree): maxDepth = 0 #dict转化为list firstSides = list(myTree.keys()) firstStr = firstSides[0] secondDict = myTree[firstStr] for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘:#test to see if the nodes are dictonaires, if not they are leaf nodes thisDepth = 1 + getTreeDepth(secondDict[key]) else: thisDepth = 1 if thisDepth > maxDepth: maxDepth = thisDepth return maxDepth def plotNode(nodeTxt, centerPt, parentPt, nodeType): createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords=‘axes fraction‘, xytext=centerPt, textcoords=‘axes fraction‘, va="center", ha="center", bbox=nodeType, arrowprops=arrow_args ) def plotMidText(cntrPt, parentPt, txtString): xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0] yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1] createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30) def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on numLeafs = getNumLeafs(myTree) #this determines the x width of this tree depth = getTreeDepth(myTree) firstSides = list(myTree.keys()) firstStr = firstSides[0] #the text label for this node should be this cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) plotMidText(cntrPt, parentPt, nodeTxt) plotNode(firstStr, cntrPt, parentPt, decisionNode) secondDict = myTree[firstStr] plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD for key in secondDict.keys(): if type(secondDict[key]).__name__==‘dict‘:#test to see if the nodes are dictonaires, if not they are leaf nodes plotTree(secondDict[key],cntrPt,str(key)) #recursion else: #it‘s a leaf node print the leaf node plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode) plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key)) plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD #if you do get a dictonary you know it‘s a tree, and the first element will be another dict #绘制决策树 def createPlot(inTree): fig = plt.figure(1, facecolor=‘white‘) fig.clf() axprops = dict(xticks=[], yticks=[]) createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #no ticks #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses plotTree.totalW = float(getNumLeafs(inTree)) plotTree.totalD = float(getTreeDepth(inTree)) plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; plotTree(inTree, (0.5,1.0), ‘‘) plt.show() #绘制树的根节点和叶子节点(根节点形状:长方形,叶子节点:椭圆形) #def createPlot(): # fig = plt.figure(1, facecolor=‘white‘) # fig.clf() # createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses # plotNode(‘a decision node‘, (0.5, 0.1), (0.1, 0.5), decisionNode) # plotNode(‘a leaf node‘, (0.8, 0.1), (0.3, 0.8), leafNode) # plt.show() def retrieveTree(i): listOfTrees =[{‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: ‘no‘, 1: ‘yes‘}}}}, {‘no surfacing‘: {0: ‘no‘, 1: {‘flippers‘: {0: {‘head‘: {0: ‘no‘, 1: ‘yes‘}}, 1: ‘no‘}}}} ] return listOfTrees[i] #thisTree = retrieveTree(0) #createPlot(thisTree) #createPlot() #myTree = retrieveTree(0) #numLeafs =getNumLeafs(myTree) #treeDepth =getTreeDepth(myTree) #print(u"叶子节点数目:%d"% numLeafs) #print(u"树深度:%d"%treeDepth)
# -*- coding: utf-8 -*- """ Created on Fri Aug 3 19:52:10 2018 @author: weixw """ import myTrees as mt import treePlotter as tp #测试 dataSet, labels = mt.createDataSet() #copy函数:新开辟一块内存,然后将list的所有值复制到新开辟的内存中 labels1 = labels.copy() #createTree函数中将labels1的值改变了,所以在分类测试时不能用labels1 myTree = mt.createTree(dataSet,labels1) #保存树到本地 mt.storeTree(myTree,‘myTree.txt‘) #在本地磁盘获取树 myTree = mt.grabTree(‘myTree.txt‘) print(u"采用C4.5算法的决策树结果") print (u"决策树结构:%s"%myTree) #绘制决策树 print(u"绘制决策树:") tp.createPlot(myTree) numLeafs =tp.getNumLeafs(myTree) treeDepth =tp.getTreeDepth(myTree) print(u"叶子节点数目:%d"% numLeafs) print(u"树深度:%d"%treeDepth) #测试分类 简单样本数据3列 labelResult =mt.classify(myTree,labels,[1,1]) print(u"[1,1] 测试结果为:%s"%labelResult) labelResult =mt.classify(myTree,labels,[1,0]) print(u"[1,0] 测试结果为:%s"%labelResult)
运行结果
不要让懒惰占据你的大脑,不要让妥协拖垮你的人生。青春就是一张票,能不能赶上时代的快车,你的步伐掌握在你的脚下。
相关推荐
lwnylslwnyls 2020-11-06
赶路人儿 2020-11-02
机器学习之家 2020-11-10
mori 2020-11-06
jaybeat 2020-11-17
jaybeat 2020-11-02
changyuanchn 2020-11-01
Micusd 2020-11-19
人工智能 2020-11-19
81510295 2020-11-17
flyfor0 2020-11-16
lgblove 2020-11-16
Pokemogo 2020-11-16
Pokemogo 2020-11-16
clong 2020-11-13
lizhengjava 2020-11-13
ohbxiaoxin 2020-11-13
Icevivian 2020-11-13
EchoYY 2020-11-12