一棵樹如果節點過多,表明該模型可能對數據進行了「過擬合」。那麼,如何判斷是否發生了過擬合?前面章節中使用了測試集上某種交叉驗證技術來發現過擬合,決策樹亦是如此。本節將對此進行討論,並分析如何避免過擬合。
通過降低決策樹的複雜度來避免過擬合的過程稱為剪枝(pruning)。其實本章前面已經進行過剪枝處理。在函數chooseBestSplit
中的提前終止條件,實際上是在進行一種所謂的預剪枝(prepruning)操作。另一種形式的剪枝需要使用測試集和訓練集,稱作後剪枝(postpruning)。本節將分析後剪枝的有效性,但首先來看一下預剪枝的不足之處。
9.4.1 預剪枝
上節兩個簡單實驗的結果還是令人滿意的,但背後存在一些問題。樹構建算法其實對輸入的參數tolS
和tolN
非常敏感,如果使用其他值將不太容易達到這麼好的效果。為了說明這一點,在Python提示符下輸入如下命令:
>>> regTrees.createTree(myMat,ops=(0,1))
與上節中只包含兩個節點的樹相比,這裡構建的樹過於臃腫,它甚至為數據集中每個樣本都分配了一個葉節點。
圖9-3中的散點圖,看上去與圖9-1非常相似。但如果仔細地觀察y軸就會發現,前者的數量級是後者的100倍。這將不是問題,對吧?現在用該數據來構建一棵新的樹(數據存放在ex2.txt
中),在Python提示符下輸入以下命令:
>>> myDat2=regTrees.loadDataSet(\'ex2.txt\')
>>> myMat2=mat(myDat2)
>>> regTrees.createTree(myMat2)
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': {\'spInd\': 0,
\'spVal\': matrix([[ 0.457563]]), \'right\': -3.6244789069767438,
\'left\': 7.9699461249999999}, \'l
.
.
0, \'spVal\': matrix([[ 0.958512]]), \'right\': 112.42895575000001,
\'left\': 105.248
2350000001}}}}
圖9-3 將圖9-1的數據的y軸放大100倍後的新數據集
不知你注意到沒有,從圖9-1數據集構建出來的樹只有兩個葉節點,而這裡構建的新樹則有很多葉節點。產生這個現象的原因在於,停止條件tolS
對誤差的數量級十分敏感。如果在選項中花費時間並對上述誤差容忍度取平方值,或許也能得到僅有兩個葉節點組成的樹:
>>> regTrees.createTree(myMat2,ops=(10000,4))
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': -2.6377193297872341,
\'left\': 101.35815937735855}
然而,通過不斷修改停止條件來得到合理結果並不是很好的辦法。事實上,我們常常甚至不確定到底需要尋找什麼樣的結果。這正是機器學習所關注的內容,計算機應該可以給出總體的概貌。
下節將討論後剪枝,即利用測試集來對樹進行剪枝。由於不需要用戶指定參數,後剪枝是一個更理想化的剪枝方法。
9.4.2 後剪枝
使用後剪枝方法需要將數據集分成測試集和訓練集。首先指定參數,使得構建出的樹足夠大、足夠複雜,便於剪枝。接下來從上而下找到葉節點,用測試集來判斷將這些葉節點合併是否能降低測試誤差。如果是的話就合併。
函數prune的偽代碼如下:
基於已有的樹切分測試數據: 如果存在任一子集是一棵樹,則在該子集遞歸剪枝過程 計算將當前兩個葉節點合併後的誤差 計算不合並的誤差 如果合併會降低誤差的話,就將葉節點合併
為瞭解實際效果,打開regTrees.py
並輸入程序清單9-3的代碼。
程序清單9-3 回歸樹剪枝函數
def isTree(obj):
return (type(obj).__name__==\'dict\')
def getMean(tree):
if isTree(tree[\'right\']): tree[\'right\'] = getMean(tree[\'right\'])
if isTree(tree[\'left\']): tree[\'left\'] = getMean(tree[\'left\'])
return (tree[\'left\']+tree[\'right\'])/2.0
def prune(tree, testData):
#❶ 沒有測試數據則對樹進行塌陷處理
if shape(testData)[0] == 0: return getMean(tree)
if (isTree(tree[\'right\']) or isTree(tree[\'left\'])):
lSet, rSet = binSplitDataSet(testData, tree[\'spInd\'],tree[\'spVal\'])
if isTree(tree[\'left\']): tree[\'left\'] = prune(tree[\'left\'], lSet)
if isTree(tree[\'right\']): tree[\'right\'] = prune(tree[\'right\'], rSet)
if not isTree(tree[\'left\']) and not isTree(tree[\'right\']):
lSet, rSet = binSplitDataSet(testData, tree[\'spInd\'],tree[\'spVal\'])
errorNoMerge = sum(power(lSet[:,-1] - tree[\'left\'],2)) +sum(power(rSet[:,-1] - tree[\'right\'],2))
treeMean = (tree[\'left\']+tree[\'right\'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print \"merging\"
return treeMean
else: return tree
else: return tree
程序清單9-3中包含三個函數:isTree
、getMean
和prune
。其中isTree
用於測試輸入變量是否是一棵樹,返回布爾類型的結果。換句話說,該函數用於判斷當前處理的節點是否是葉節點。
函數getMean
是一個遞歸函數,它從上往下遍歷樹直到葉節點為止。如果找到兩個葉節點則計算它們的平均值。該函數對樹進行塌陷處理(即返回樹平均值),在prune
函數中調用該函數時應明確這一點。
程序清單9-3的主函數是prune
,它有兩個參數:待剪枝的樹與剪枝所需的測試數據testData
。prune
函數首先需要確認測試集是否為空❶。一旦非空,則反覆遞歸調用函數prune
對測試數據進行切分。因為樹是由其他數據集(訓練集)生成的,所以測試集上會有一些樣本與原數據集樣本的取值範圍不同。一旦出現這種情況應當怎麼辦?數據發生過擬合應該進行剪枝嗎?或者模型正確不需要任何剪枝?這裡假設發生了過擬合,從而對樹進行剪枝。
接下來要檢查某個分支到底是子樹還是節點。如果是子樹,就調用函數prune
來對該子樹進行剪枝。在對左右兩個分支完成剪枝之後,還需要檢查它們是否仍然還是子樹。如果兩個分支已經不再是子樹,那麼就可以進行合併。具體做法是對合併前後的誤差進行比較。如果合併後的誤差比不合並的誤差小就進行合併操作,反之則不合並直接返回。
接下來看看實際效果,將程序清單9-3的代碼添加到regTrees.py
文件並保存,在Python提示符下輸入下面的命令:
>>> reload(regTrees)
<module \'regTrees\' from \'regTrees.pyc\'>
為了創建所有可能中最大的樹,輸入如下命令:
>>> myTree=regTrees.createTree(myMat2, ops=(0,1))
輸入以下命令導入測試數據:
>>> myDatTest=regTrees.loadDataSet(\'ex2test.txt\')
>>> myMat2Test=mat(myDatTest)
輸入以下命令,執行剪枝過程:
>>> regTrees.prune(myTree, myMat2Test)
merging
merging
merging
.
.
merging
{\'spInd\': 0, \'spVal\': matrix([[ 0.499171]]), \'right\': {\'spInd\': 0, \'spVal\':
.
.
01, \'left\': {\'spInd\': 0, \'spVal\': matrix([[ 0.960398]]), \'right\': 123.559747,
\'left\': 112.386764}}}, \'left\': 92.523991499999994}}}}
可以看到,大量的節點已經被剪枝掉了,但沒有像預期的那樣剪枝成兩部分,這說明後剪枝可能不如預剪枝有效。一般地,為了尋求最佳模型可以同時使用兩種剪枝技術。
下節將重用部分已有的樹構建代碼來創建一種新的樹。該樹仍採用二元切分,但葉節點不再是簡單的數值,取而代之的是一些線性模型。