机器学习:使用Matplotlib注解绘制树形图以及实例运用
发布日期:2021-06-30 15:41:22 浏览次数:2 分类:技术文章

本文共 8854 字,大约阅读时间需要 29 分钟。

 


前言

内容承接上篇文章:

这篇补充上篇没有说到的知识点以及运用。


 

一、增益率

采用与信息增益相同的表达式,增益率定义为:

其中:

称为属性a的“固有值”(intrinsic value)[Quinlan, 1993]。属性a的可能取值数目越多(即V越大),则IV(a)的值通常会越大。

不过,增益率准则对于可取值数目少的属性又有所偏好,因此,C4.5算法并不是直接选择增益率最大的候选划分属性,而使用一个启发式算法:先从候选划分属性中找出信息增益高出平均水平的属性,再从中选择增益率最高的。

二、基尼系数

CART(Classification and Regression Tree)决策树[Breiman et al., 1984]则使用“基尼系数”(Gini index)来选择划分属性。数据集D的纯度可用基尼系数度量如下:

 

直观地讲,Gini(D)反映了从数据集D中随机抽取两个样本,其类别标记不一致的概率。因此,Gini(D)越小,D的纯度越高。

在此基础上,给出属性a的基尼指数:

于是,我们可以选择基尼指数最小的属性作为最优划分属性。 

三、决策树的剪枝处理

决策树生成算法递归地产生决策树,直到不能继续下去未为止。这样产生的树往往对训练数据的分类很准确,但对未知的测试数据的分类却没有那么准确,即出现过拟合现象。过拟合的原因在于学习时过多地考虑如何提高对训练数据的正确分类,从而构建出过于复杂的决策树。解决这个问题的办法是考虑决策树的复杂度,对已生成的决策树进行简化。

决策树的剪枝的基本策略主要有:

  • 预剪枝(prepruning):在决策树生成过程中,对每个节点在划分前先进行估计,若当前节点不能提升决策树的泛化性能,则停止划分并将当前节点标记为叶节点;
  • 后剪枝(postpruning):先从训练集生成一棵决策树,然后自底向上考察非叶节点,若将该节点替换为叶节点能提升决策树的泛化性能,则将该节点替换为叶子节点。

为考察泛化能力,可以预留一部分“验证集”以进行模型评估。

值得注意的是,预剪枝虽然显著减少了训练时间和测试时间的开销,但却带来了欠拟合的风险。因为有些分支可能在当前划分无法提升泛化性能,却在后续划分中可以做到。而后剪枝决策树在一般情形下欠拟合风险更小,且泛化性能往往优于预剪枝决策树,不过代价是要付出大得多的训练时间开销。

顺便一提,经过剪枝后,仅有一层划分的决策树,也被称为“决策树桩”(decision stump)。

实现方式:极小化决策树整体的损失函数或代价函数来实现

决策树学习的损失函数定义为:

四、决策树连续与缺失值处理

1.连续值处理

前面讨论的是基于离散属性生成的决策树。然而在现实任务中,时常会遇到连续属性,此时便不能直接根据连续属性的值来划分节点。需要对连续属性离散化。

最简单的策略是二分法(bi-partition)。给定样本集D和连续属性a,假定a在D上出现n个不同的取值,从小到大排序记为{a1, a2, ..., an}。于是,对于连续属性a,可以考虑n-1个元素的候选划分点集合Ta = {(ai+ai+1)/2 | 1 ≤ i ≤ n-1}。于是,在此基础上,可以对信息增益加以改造。

 

2.缺失值处理

 现实任务中常会遇到不完整样本,即样本的某些属性值确实。例如由于诊测成本、隐私保护等因素,患者的医疗数据在某些属性上的取值(如HIV测试结果)。如果简单放弃不完整样本,显然是对数据的极大浪费。为充分利用数据,需要解决两个问题:

  • 如何在属性值缺失的情况下进行划分属性选择?
  • 给定划分属性,若样本在该属性上缺失,如何对样本进行划分?

给定训练集D和属性a,令表示D中在属性a上没有缺失的样本子集。对于问题1,显然仅可以根据来判断属性a的优劣。假定属性a有V个可取的值{a1, a2, ..., aV},令表示中在属性a上取值为av的样本子集,表示中属于第k类(k=1,2,...,|Y|)的样本子集。则显然有

假定为每个样本x赋以权重ωx,并定义

 

 

显然,

基于上述定义,可以将信息增益公式推广为

对于问题2,若样本x在属性a上的取值已知,则划入对应子节点,并保持样本权值即可;若取值未知,则同时划入所有子节点,且样本权值在属性值av对应的子节点中调整为

五、多变量决策树

将每个属性视为坐标空间的一个坐标轴,则由d个属性描述的样本,对应于d维空间中的一个数据点。对样本分类,意味着在此坐标空间中寻找不容类样本间的分类边界。而决策树所形成的分类边界 有一个明显的特点:轴平行(axis-parallel),即其分类边界由若干个与轴平行的分段组成。这一特点的好处在于有较好的可解释性,但为了近似比较复杂的分类边界,会导致决策树过于复杂。为解决此问题,引入多变量决策树。

多变量决策树(multivariate decision tree)就能实现用斜线划分、甚至更复杂的划分。在此类决策树中,非叶节点不再仅是某个属性,而是对属性的线性组合进行测试,i.e.每个非叶节点都是一个形如的线性分类器。下面两张图给出了决策树和多变量决策树分类结果的对比。

 六、Matplotlib注解绘制树形图

 1.使用文本注解绘制树节点

先做个样例演示一下将来树的形状:

import matplotlib.pyplot as plt#定义文本框和箭头格式decisionNode = dict(boxstyle="sawtooth",fc="0.8")leafNode = dict(boxstyle="round4",fc="0.8")arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):    createPlot.ax1.annotate(nodeTxt, xy=parentPt, \    xycoords='axes fraction',    xytext=centerPt, textcoords='axes fraction',\    va="center",ha="center", bbox=nodeType, arrowprops=arrow_args)def createPlot():    fig = plt.figure(1,facecolor='white')    fig.clf()    createPlot.ax1 = plt.subplot(111,frameon=False)    plotNode('decisionNode',(0.5,0.1),(0.1,0.5),decisionNode)    plotNode('leafNode',(0.8,0.1),(0.3,0.8),leafNode)    plt.show()

 

plt函数不用我多解释了吧。这里讲一下annotate 函数,蛮好玩的:

    import matplotlib.pyplot as plt

    # plt.annotate(str, xy=data_point_position, xytext=annotate_position, 
    #              va="center",  ha="center", xycoords="axes fraction", 
    #              textcoords="axes fraction", bbox=annotate_box_type, arrowprops=arrow_style)
    # str是给数据点添加注释的内容,支持输入一个字符串
    # xy=是要添加注释的数据点的位置
    # xytext=是注释内容的位置
    # bbox=是注释框的风格和颜色深度,fc越小,注释框的颜色越深,支持输入一个字典
    # va="center",  ha="center"表示注释的坐标以注释框的正中心为准,而不是注释框的左下角(v代表垂直方向,h代表水平方向)
    # xycoords和textcoords可以指定数据点的坐标系和注释内容的坐标系,通常只需指定xycoords即可,textcoords默认和xycoords相同
    # arrowprops可以指定箭头的风格支持,输入一个字典
    # plt.annotate()的详细参数可用__doc__查看,如:print(plt.annotate.__doc__)
 

不懂的可以看看这个博客:

2.构造注解树

def getNumLeafs(myTree):    numLeafs = 0    # firstStr = myTree.key()[0]    firstSides = list(myTree.keys())    firstStr = firstSides[0]  # 找到输入的第一个元素    secondDict = myTree[firstStr]    #根据判断节点是否为字典类型来判断是否为叶子节点。    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            numLeafs += getNumLeafs(secondDict[key])        else:            numLeafs += 1    return numLeafsdef getTreeDepth(myTree):    maxDepth = 0    firstSides = list(myTree.keys())    firstStr = firstSides[0]  # 找到输入的第一个元素    #firstStr = myTree.keys()[0]    secondDict = myTree[firstStr]    # 根据判断节点是否为字典来判断是否进行了分支。根据分支计算层数    for key in secondDict.keys():        if type(secondDict[key]).__name__=='dict':            thisDepth = 1+getTreeDepth(secondDict[key])        else:            thisDepth = 1        if thisDepth > maxDepth:            maxDepth = thisDepth    return maxDepth

 原书上是:

# firstStr = myTree.key()[0]

但是py3以上就不能实现这个功能了,很简单用list将myTreekeys转为列表,再将第一个元素转给fristStr就好了。

def retrieveTree(i):    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},    {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}    ]    return listOfTrees[i]

 先给一些树进行测试吧。

mytree=retrieveTree(0)print(mytree)print(getTreeDepth(mytree))print(getNumLeafs(mytree))

OK,接下来就是拼接了。just code it!

 设置一个中间文本用来记录对于该特征对应属性。

def plotMidText(cntrPt, parentPt, txtString):    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0] #中间文本坐标的x    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1] #中间文本坐标的y    createPlot.ax1.text(xMid, yMid, txtString)

 想整好看点的可以修改一下text函数,这里附赠一下用法:

# plt.text()函数用法# 第一个参数是x轴坐标# 第二个参数是y轴坐标# 第三个参数是要显式的内容# alpha 设置字体的透明度# family 设置字体# size 设置字体的大小# style 设置字体的风格# wight 字体的粗细# bbox 给字体添加框,alpha 设置框体的透明度, facecolor 设置框体的颜色
def plotTree(myTree, parentPt, nodeTxt):    numLeafs = getNumLeafs(myTree) #计算树的宽度    depth = getTreeDepth(myTree)   #计算树的高度    firstSides = list(myTree.keys())    firstStr = firstSides[0]  # 找到输入的第一个元素    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)    plotMidText(cntrPt, parentPt, nodeTxt) #特征属性值文本    plotNode(firstStr, cntrPt, parentPt, decisionNode) #叶子节点    secondDict = myTree[firstStr]    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD #减少y的便偏移    for key in secondDict.keys():        if type(secondDict[key]).__name__ == 'dict':            plotTree(secondDict[key], cntrPt, str(key))        else:            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

函数plotTree()首先计算树的宽和高。

def createPlot(inTree):    fig = plt.figure(1, facecolor='white')    fig.clf()# Clear figure清除所有轴,但是窗口打开,这样它可以被重复使用     # cla() :Clear axis即清除当前图形中的当前活动轴。其他轴不受影响    axprops = dict(xticks=[], yticks=[])    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    plotTree.totalW = float(getNumLeafs(inTree))    plotTree.totalD = float(getTreeDepth(inTree))    plotTree.xOff = -0.5 / plotTree.totalW    plotTree.yOff = 1.0    plotTree(inTree, (0.5, 1.0), '')    plt.show()

 

全局变量plotTree.totalW存储树的宽度,全以将树绘制在水平方向和垂直方向的中心位置。函数plotTree()也是个递归函数。树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放置所有叶子节点在中间,而不仅仅是它子节点的中间。同时我们使用两个全局变量plotTree.xOff和plotTree0.yOff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置,以及放置下一个节点的恰当位置。

(plt函数还蛮多的,遇到一个学一个先)

(1)figure语法说明

figure(num=None, figsize=None, dpi=None, facecolor=None, edgecolor=None, frameon=True)

  • num:图像编号或名称,数字为编号 ,字符串为名称
  • figsize:指定figure的宽和高,单位为英寸;
  • dpi参数指定绘图对象的分辨率,即每英寸多少个像素,缺省值为80      1英寸等于2.5cm,A4纸是 21*30cm的纸张 
  • facecolor:背景颜色
  • edgecolor:边框颜色
  • frameon:是否显示边框

 

最后实现:

七、测试算法

依靠训练数据构造了决策树之后,我们可以将它用于实际数据的分类。在执行数据分类时,需要使用决策树以及用于构造决策树的标签向量。然后,程序比较测试数据与决策树上的数值递归执行该过程指导进入叶子节点;最后将测试数据定义为叶子节点所属的类型。

接下来我们使用决策树的分类函数:(将代码加入trees.py,也就是上一篇trees中)

def classify(inputTree,featLabels,testVec):    firstSide = list(inputTree.keys())    firstStr = firstSide[0]    secondDict = inputTree[firstStr]    featIndex = featLabels.index(firstStr) #index() 函数用于从列表中找出某个值第一个匹配项的索引位置。->test3    print(featIndex)    print(testVec[featIndex])    for key in secondDict.keys():        if testVec[featIndex] == key:            if type(secondDict[key]).__name__=='dict':                classLabel = classify(secondDict[key],featLabels,testVec)            else:   classLabel = secondDict[key]    return  classLabel

我们现在可以测试一下:

MydDat,labels=createDataSet()print(labels)mytree=treePlotter.retrieveTree(0)print(mytree)print(classify(mytree,labels,[1,0]))

 在存储带有特征的数据会面临一个问题:程序无法确定特征在数据集中的位置。特征标签李彪将帮助程序处理这个问题。使用index方法查找当前列表中第一个匹配firstStr变量的元素。然后代码递归遍历整个树,比较testVec变量中的值与树节点的值,如果到达叶子节点,则返回当前节点的分类标签。

八、决策树的存储

相信大家实现了决策树肯定都会有这样一个疑惑,如歌给出超大数据包括超多特征,那么采取递归的方式岂不是要好久才能遍历完?然而用创建好了的决策树解决分类问题,则可以很快完成。因此,为了节省时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块pickle序列化对象。序列化对象可以在磁盘上保存对象,并在需要的时候读取出来。任何对象都可以执行序列化操作,字典对象也不例外。

def storeTree(inputTree,filename):    import  pickle    fw = open(filename,'w')    pickle.dump(inputTree,fw)    fw.close()def garbTree(filename):    import pickle    fr = open(filename)    return pickle.load(fr)

通过上面的代码,我们可以将分类器存储在硬盘上,二不用每次分化对数据分类时重新学习一遍。

九、使用决策树预测隐形眼镜类型

按流程走:收集数据->准备数据->分析数据->训练算法->测试算法->使用算法

隐形眼镜数据集是非常著名的数据集。我们看看其样式:

 

 

使用该决策树可能会有多种情况,其不同点在于划分方式,在于信息增益。我们可以使用剪枝算法进行优化,得到符合预期想法的树。

本章主要使用的是ID3算法,自身也存在着很多不足。

总结

要用好决策树还是蛮难的,多做几个类似项目预测试试。

参阅:

转载地址:https://jxnuxwt.blog.csdn.net/article/details/109178471 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:动态添加和删除Datanode的方法
下一篇:HbaseJAVA开发API导入jar包以及实现操作命令

发表评论

最新留言

关注你微信了!
[***.104.42.241]2024年04月09日 01时49分13秒