上節我們已經學習了如何從數據集中創建樹,然而字典的表示形式非常不易於理解,而且直接繪製圖形也比較困難。本節我們將使用Matplotlib庫創建樹形圖。決策樹的主要優點就是直觀易於理解,如果不能將其直觀地顯示出來,就無法發揮其優勢。雖然前面章節我們使用的圖形庫已經非常強大,但是Python並沒有提供繪製樹的工具,因此我們必須自己繪製樹形圖。本節我們將學習如何編寫代碼繪製如圖3-3所示的決策樹。
圖3-3 決策樹的範例
3.2.1 Matplotlib註解
Matplotlib提供了一個註解工具annotations
,非常有用,它可以在數據圖形上添加文本註釋。註解通常用於解釋數據的內容。由於數據上面直接存在文本描述非常醜陋,因此工具內嵌支持帶箭頭的劃線工具,使得我們可以在其他恰當的地方指向數據位置,並在此處添加描述信息,解釋數據內容。如圖3-4所示,在坐標(0.2, 0.1)的位置有一個點,我們將對該點的描述信息放在(0.35, 0.3)的位置,並用箭頭指向數據點(0.2, 0.1)。
圖3-4 Matplotlib註解示例
繪製還是圖形化
為什麼使用單詞「繪製」(plot)?為什麼在討論如何在圖形上顯示數據的時候不使用單詞「圖形化」(graph)?這裡存在一些語言上的差別,英語單詞graph在某些學科中具有特定的含義,如在應用數學中,一系列由邊連接在一起的對象或者節點稱為圖。節點的任意聯繫都可以通過邊來連接。在計算機科學中,圖是一種數據結構,用於表示數學上的概念。好在漢語並不存在這些混淆的概念,這裡就統一使用繪製樹形圖。
本書將使用Matplotlib的註解功能繪製樹形圖,它可以對文字著色並提供多種形狀以供選擇,而且我們還可以反轉箭頭,將它指向文本框而不是數據點。打開文本編輯器,創建名為treePlotter.py的新文件,然後輸入下面的程序代碼。
程序清單3-5 使用文本註解繪製樹節點
import matplotlib.pyplot as plt
#❶ (以下三行)定義文本框和箭頭格式
decisionNode = dict(box, fc=\"0.8\")
leafNode = dict(box, fc=\"0.8\")
arrow_args = dict(arrow)
#❷ (以下兩行)繪製帶箭頭的註解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt,
xycoords=\'axes fraction\',
xytext=centerPt, textcoords=\'axes fraction\',
va=\"center\", ha=\"center\", bbox=nodeType, arrowprops=arrow_args)
def createPlot:
fig = plt.figure(1, facecolor=\'white\')
fig.clf
reatePlot.ax1 = plt.subplot(111, frameon=False)
plotNode(\'決策節點\', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode(\'葉節點\', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show
這是第一個版本的createPlot
函數,與例子文件中的createPlot
函數有些不同,隨著內容的深入,我們將逐步添加缺失的代碼。代碼定義了樹節點格式的常量❶。然後定義plotNode
函數執行了實際的繪圖功能,該函數需要一個繪圖區,該區域由全局變量createPlot.ax1
定義。Python語言中所有的變量默認都是全局有效的,只要我們清楚知道當前代碼的主要功能,並不會引入太大的麻煩。最後定義createPlot
函數,它是這段代碼的核心。createPlot
函數首先創建了一個新圖形並清空繪圖區,然後在繪圖區上繪製兩個代表不同類型的樹節點,後面我們將用這兩個節點繪製樹形圖。
為了測試上面代碼的實際輸出結果,打開Python命令提示符,導入treePlotter
模塊:
>>> import treePlotter
>>> treePlotter.createPlot
程序的輸出結果如圖3-5所示,我們也可以改變函數plotNode
❷,觀察圖中x、y位置如何變化。
圖3-5 函數plotNode
的例子
現在我們已經掌握了如何繪製樹節點,下面將學習如何繪製整棵樹。
3.2.2 構造註解樹
繪製一棵完整的樹需要一些技巧。我們雖然有x、y坐標,但是如何放置所有的樹節點卻是個問題。我們必須知道有多少個葉節點,以便可以正確確定x軸的長度;我們還需要知道樹有多少層,以便可以正確確定y軸的高度。這裡我們定義兩個新函數getNumLeafs
和getTreeDepth
,來獲取葉節點的數目和樹的層數,參見程序清單3-6,並將這兩個函數添加到文件treePlotter.py中。
程序清單3-6 獲取葉節點的數目和樹的層數
def getNumLeafs(myTree):
numLeafs = 0
firstStr = myTree.keys[0]
secondDict = myTree[firstStr]
for key in secondDict.keys:
#❶ (以下三行)測試節點的數據類型是否為字典
if type(secondDict[key]).__name__==\'dict\':
numLeafs += getNumLeafs(secondDict[key])
else: numLeafs +=1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = myTree.keys[0]
secondDict = myTree[firstStr]
for key in secondDict.keys:
if type(secondDict[key]).__name__==\'dict\':
thisDepth = 1 + getTreeDepth(secondDict[key])
else: thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
上述程序中的兩個函數具有相同的結構,後面我們也將使用到這兩個函數。這裡使用的數據結構說明了如何在Python字典類型中存儲樹信息。第一個關鍵字是第一次劃分數據集的類別標籤,附帶的數值表示子節點的取值。從第一個關鍵字出發,我們可以遍歷整棵樹的所有子節點。使用Python提供的type
函數可以判斷子節點是否為字典類型❶。如果子節點是字典類型,則該節點也是一個判斷節點,需要遞歸調用getNumLeafs
函數。getNumLeafs
函數遍歷整棵樹,累計葉子節點的個數,並返回該數值。第2個函數getTreeDepth
計算遍歷過程中遇到判斷節點的個數。該函數的終止條件是葉子節點,一旦到達葉子節點,則從遞歸調用中返回,並將計算樹深度的變量加一。為了節省大家的時間,函數retrieveTree
輸出預先存儲的樹信息,避免了每次測試代碼時都要從數據中創建樹的麻煩。
添加下面的代碼到文件treePlotter.py中:
def retrieveTree(i):
listOfTrees =[{\'no surfacing\': {0: \'no\', 1: {\'flippers\':
{0: \'no\', 1: \'yes\'}}}},
{\'no surfacing\': {0: \'no\', 1: {\'flippers\':
{0: {\'head\': {0: \'no\', 1: \'yes\'}}, 1: \'no\'}}}}
]
return listOfTrees[i]
保存文件treePlotter.py,在Python命令提示符下輸入下列命令:
>>> reload(treePlotter)
<module \'treePlotter\' from \'treePlotter.py\'>
>>> treePlotter.retrieveTree (1)
{\'no surfacing\': {0: \'no\', 1: {\'flippers\': {0: {\'head\': {0: \'no\', 1:
\'yes\'}}, 1: \'no\'}}}}</pre>
>>> myTree = treePlotter.retrieveTree (0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)
2
函數retrieveTree
主要用於測試,返回預定義的樹結構。上述命令中調用getNumLeafs
函數返回值為3,等於樹0的葉子節點數;調用getTreeDepths
函數也能夠正確返回樹的層數。
現在我們可以將前面學到的方法組合在一起,繪製一棵完整的樹。最終的結果如圖3-6所示,但是沒有x和y軸標籤。
圖3-6 簡單數據集繪製的樹形圖
打開文本編輯器,將程序清單3-7的內容添加到treePlotter.py文件中。注意,前文已經在文件中定義了函數createPlot
,此處我們需要更新這部分代碼。
程序清單3-7 plotTree
函數
#❶ (以下四行)在父子節點間填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va=\"center\", ha=\"center\", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
#❷(以下兩行)計算寬和高
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = myTree.keys[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
#❸ 標記子節點屬性值
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
#❹(以下兩行)減小y偏移
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
for key in secondDict.keys:
if type(secondDict[key]).__name__==\'dict\':
plotTree(secondDict[key],cntrPt,str(key)) #recursion
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),, cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor=\'white\')
fig.clf
axprops = dict(xticks=, yticks=)
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(getNumLeafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
plotTree(inTree, (0.5,1.0), \'\')
plt.show
函數createPlot
是我們使用的主函數,它調用了plotTree
,函數plotTree
又依次調用了前面介紹的函數和plotMidText
。繪製樹形圖的很多工作都是在函數plotTree
中完成的,函數plotTree
首先計算樹的寬和高❷。全局變量plotTree.totalW
存儲樹的寬度,全局變量plotTree.totalD
存儲樹的深度,我們使用這兩個變量計算樹節點的擺放位置,這樣可以將樹繪製在水平方向和垂直方向的中心位置。與程序清單3-6中的函數getNumLeafs
和getTreeDepth
類似,函數plotTree
也是個遞歸函數。樹的寬度用於計算放置判斷節點的位置,主要的計算原則是將它放在所有葉子節點的中間,而不僅僅是它子節點的中間。同時我們使用兩個全局變量plotTree.xOff
和plotTree.yOff
追蹤已經繪製的節點位置,以及放置下一個節點的恰當位置。另一個需要說明的問題是,繪製圖形的x軸有效範圍是0.0到1.0,y軸有效範圍也是0.0~1.0。為了方便起見,圖3-6給出具體坐標值,實際輸出的圖形中並沒有xy坐標。通過計算樹包含的所有葉子節點數,劃分圖形的寬度,從而計算得到當前節點的中心位置,也就是說,我們按照葉子節點的數目將x軸劃分為若幹部分。按照圖形比例繪製樹形圖的最大好處是無需關心實際輸出圖形的大小,一旦圖形大小發生了變化,函數會自動按照圖形大小重新繪製。如果以像素為單位繪製圖形,則縮放圖形就不是一件簡單的工作。
接著,繪出子節點具有的特徵值,或者沿此分支向下的數據實例必須具有的特徵值❸。使用函數plotMidText
計算父節點和子節點的中間位置,並在此處添加簡單的文本標籤信息❶。
然後,按比例減少全局變量plotTree.yOff
,並標注此處將要繪製子節點❹,這些節點既可以是葉子節點也可以是判斷節點,此處需要只保存繪製圖形的軌跡。因為我們是自頂向下繪製圖形,因此需要依次遞減y坐標值,而不是遞增y坐標值。然後程序採用函數getNumLeafs
和getTreeDepth
以相同的方式遞歸遍歷整棵樹,如果節點是葉子節點則在圖形上畫出葉子節點,如果不是葉子節點則遞歸調用plotTree
函數。在繪製了所有子節點之後,增加全局變量Y的偏移。
程序清單3-7的最後一個函數是createPlot
,它創建繪圖區,計算樹形圖的全局尺寸,並調用遞歸函數plotTree
。
現在我們可以驗證一下實際的輸出效果。添加上述代碼到文件treePlotter.py之後,在Python命令提示符下輸入下列命令:
>>> reload(treePlotter)
<module \'treePlotter\' from \'treePlotter.pyc\'>
>>> myTree=treePlotter.retrieveTree (0)
>>> treePlotter.createPlot(myTree)
輸出效果如圖3-6所示,但是沒有坐標軸標籤。接著按照如下命令變更字典,重新繪製樹形圖:
>>> myTree[\'no surfacing\'][3]=\'maybe\'
>>> myTree
{\'no surfacing \': {0: \'no\', 1: {\'flippers\': {0: \'no\', 1: \'yes\'}}, 3:
\'maybe\'}}
>>> treePlotter.createPlot(myTree)
輸出效果如圖3-7所示,有點像一個無頭的簡筆畫。你也可以在樹字典中隨意添加一些數據,並重新繪製樹形圖觀察輸出結果的變化。
到目前為止,我們已經學習了如何構造決策樹以及繪製樹形圖的方法,下節我們將實際使用這些方法,並從數據和算法中得到某些新知識。
圖3-7 超過兩個分支的樹形圖