讀古今文學網 > 機器學習實戰 > 9.4 樹剪枝 >

9.4 樹剪枝

一棵樹如果節點過多,表明該模型可能對數據進行了「過擬合」。那麼,如何判斷是否發生了過擬合?前面章節中使用了測試集上某種交叉驗證技術來發現過擬合,決策樹亦是如此。本節將對此進行討論,並分析如何避免過擬合。

通過降低決策樹的複雜度來避免過擬合的過程稱為剪枝(pruning)。其實本章前面已經進行過剪枝處理。在函數chooseBestSplit中的提前終止條件,實際上是在進行一種所謂的預剪枝(prepruning)操作。另一種形式的剪枝需要使用測試集和訓練集,稱作後剪枝(postpruning)。本節將分析後剪枝的有效性,但首先來看一下預剪枝的不足之處。

9.4.1 預剪枝

上節兩個簡單實驗的結果還是令人滿意的,但背後存在一些問題。樹構建算法其實對輸入的參數tolStolN非常敏感,如果使用其他值將不太容易達到這麼好的效果。為了說明這一點,在Python提示符下輸入如下命令:

>>> regTrees.createTree(myMat,ops=(0,1))
  

與上節中只包含兩個節點的樹相比,這裡構建的樹過於臃腫,它甚至為數據集中每個樣本都分配了一個葉節點。

圖9-3中的散點圖,看上去與圖9-1非常相似。但如果仔細地觀察y軸就會發現,前者的數量級是後者的100倍。這將不是問題,對吧?現在用該數據來構建一棵新的樹(數據存放在ex2.txt中),在Python提示符下輸入以下命令:

>>> myDat2=regTrees.loadDataSet(\'ex2.txt\')
>>> myMat2=mat(myDat2)
>>> regTrees.createTree(myMat2)
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': {\'spInd\': 0,
 \'spVal\': matrix([[ 0.457563]]), \'right\': -3.6244789069767438,
 \'left\': 7.9699461249999999}, \'l
.
.
0, \'spVal\': matrix([[ 0.958512]]), \'right\': 112.42895575000001,
\'left\': 105.248
2350000001}}}}
  

圖9-3 將圖9-1的數據的y軸放大100倍後的新數據集

不知你注意到沒有,從圖9-1數據集構建出來的樹只有兩個葉節點,而這裡構建的新樹則有很多葉節點。產生這個現象的原因在於,停止條件tolS對誤差的數量級十分敏感。如果在選項中花費時間並對上述誤差容忍度取平方值,或許也能得到僅有兩個葉節點組成的樹:

>>> regTrees.createTree(myMat2,ops=(10000,4))
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': -2.6377193297872341,
 \'left\': 101.35815937735855}
  

然而,通過不斷修改停止條件來得到合理結果並不是很好的辦法。事實上,我們常常甚至不確定到底需要尋找什麼樣的結果。這正是機器學習所關注的內容,計算機應該可以給出總體的概貌。

下節將討論後剪枝,即利用測試集來對樹進行剪枝。由於不需要用戶指定參數,後剪枝是一個更理想化的剪枝方法。

9.4.2 後剪枝

使用後剪枝方法需要將數據集分成測試集和訓練集。首先指定參數,使得構建出的樹足夠大、足夠複雜,便於剪枝。接下來從上而下找到葉節點,用測試集來判斷將這些葉節點合併是否能降低測試誤差。如果是的話就合併。

函數prune的偽代碼如下:

基於已有的樹切分測試數據:
    如果存在任一子集是一棵樹,則在該子集遞歸剪枝過程
    計算將當前兩個葉節點合併後的誤差
    計算不合並的誤差
    如果合併會降低誤差的話,就將葉節點合併
  

為瞭解實際效果,打開regTrees.py並輸入程序清單9-3的代碼。

程序清單9-3 回歸樹剪枝函數

def isTree(obj):
    return (type(obj).__name__==\'dict\')

def getMean(tree):
    if isTree(tree[\'right\']): tree[\'right\'] = getMean(tree[\'right\'])
    if isTree(tree[\'left\']): tree[\'left\'] = getMean(tree[\'left\'])
    return (tree[\'left\']+tree[\'right\'])/2.0

def prune(tree, testData):
   #❶  沒有測試數據則對樹進行塌陷處理
    if shape(testData)[0] == 0: return getMean(tree)
    if (isTree(tree[\'right\']) or isTree(tree[\'left\'])):
        lSet, rSet = binSplitDataSet(testData, tree[\'spInd\'],tree[\'spVal\'])
    if isTree(tree[\'left\']): tree[\'left\'] = prune(tree[\'left\'], lSet)
    if isTree(tree[\'right\']): tree[\'right\'] = prune(tree[\'right\'], rSet)
    if not isTree(tree[\'left\']) and not isTree(tree[\'right\']):
        lSet, rSet = binSplitDataSet(testData, tree[\'spInd\'],tree[\'spVal\'])
        errorNoMerge = sum(power(lSet[:,-1] - tree[\'left\'],2)) +sum(power(rSet[:,-1] - tree[\'right\'],2))
        treeMean = (tree[\'left\']+tree[\'right\'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge:
            print \"merging\"
            return treeMean
        else: return tree
    else: return tree   
 

程序清單9-3中包含三個函數:isTreegetMeanprune。其中isTree用於測試輸入變量是否是一棵樹,返回布爾類型的結果。換句話說,該函數用於判斷當前處理的節點是否是葉節點。

函數getMean是一個遞歸函數,它從上往下遍歷樹直到葉節點為止。如果找到兩個葉節點則計算它們的平均值。該函數對樹進行塌陷處理(即返回樹平均值),在prune函數中調用該函數時應明確這一點。

程序清單9-3的主函數是prune,它有兩個參數:待剪枝的樹與剪枝所需的測試數據testDataprune函數首先需要確認測試集是否為空❶。一旦非空,則反覆遞歸調用函數prune對測試數據進行切分。因為樹是由其他數據集(訓練集)生成的,所以測試集上會有一些樣本與原數據集樣本的取值範圍不同。一旦出現這種情況應當怎麼辦?數據發生過擬合應該進行剪枝嗎?或者模型正確不需要任何剪枝?這裡假設發生了過擬合,從而對樹進行剪枝。

接下來要檢查某個分支到底是子樹還是節點。如果是子樹,就調用函數prune來對該子樹進行剪枝。在對左右兩個分支完成剪枝之後,還需要檢查它們是否仍然還是子樹。如果兩個分支已經不再是子樹,那麼就可以進行合併。具體做法是對合併前後的誤差進行比較。如果合併後的誤差比不合並的誤差小就進行合併操作,反之則不合並直接返回。

接下來看看實際效果,將程序清單9-3的代碼添加到regTrees.py文件並保存,在Python提示符下輸入下面的命令:

>>> reload(regTrees)
<module \'regTrees\' from \'regTrees.pyc\'> 
  

為了創建所有可能中最大的樹,輸入如下命令:

>>> myTree=regTrees.createTree(myMat2, ops=(0,1))
  

輸入以下命令導入測試數據:

>>> myDatTest=regTrees.loadDataSet(\'ex2test.txt\')
>>> myMat2Test=mat(myDatTest) 
  

輸入以下命令,執行剪枝過程:

>>> regTrees.prune(myTree, myMat2Test)
merging
merging
merging
.
.
merging
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': {\'spInd\': 0, \'spVal\':
.
.
01, \'left\': {\'spInd\': 0, \'spVal\': matrix([[ 0.960398]]), \'right\': 123.559747,
    \'left\': 112.386764}}}, \'left\': 92.523991499999994}}}} 
  

可以看到,大量的節點已經被剪枝掉了,但沒有像預期的那樣剪枝成兩部分,這說明後剪枝可能不如預剪枝有效。一般地,為了尋求最佳模型可以同時使用兩種剪枝技術。

下節將重用部分已有的樹構建代碼來創建一種新的樹。該樹仍採用二元切分,但葉節點不再是簡單的數值,取而代之的是一些線性模型。