要對數據的複雜關係建模,我們已經決定借用樹結構來幫助切分數據,那麼如何實現數據的切分呢?怎麼才能知道是否已經充分切分呢?這些問題的答案取決於葉節點的建模方式。回歸樹假設葉節點是常數值,這種策略認為數據中的複雜關係可以用樹結構來概括。
為成功構建以分段常數為葉節點的樹,需要度量出數據的一致性。第3章使用樹進行分類,會在給定節點時計算數據的混亂度。那麼如何計算連續型數值的混亂度呢?事實上,在數據集上計算混亂度是非常簡單的。首先計算所有數據的均值,然後計算每條數據的值到均值的差值。為了對正負差值同等看待,一般使用絕對值或平方值來代替上述差值。上述做法有點類似於前面介紹過的統計學中常用的方差計算。唯一的不同就是,方差是平方誤差的均值(均方差),而這裡需要的是平方誤差的總值(總方差)。總方差可以通過均方差乘以數據集中樣本點的個數來得到。
有了上述誤差計算準則和上一節中的樹構建算法,下面就可以開始構建數據集上的回歸樹了。
9.3.1 構建樹
構建回歸樹,需要補充一些新的代碼,使程序清單9-1中的函數createTree
得以運轉。首先要做的就是實現chooseBestSplit
函數。給定某個誤差計算方法,該函數會找到數據集上最佳的二元切分方式。另外,該函數還要確定什麼時候停止切分,一旦停止切分會生成一個葉節點。因此,函數chooseBestSplit
只需完成兩件事:用最佳方式切分數據集和生成相應的葉節點。
從程序清單9-1可以看出,除了數據集以外,函數chooseBestSplit
還有leafType
、errType
和ops
這三個參數。其中leafType
是對創建葉節點的函數的引用,errType
是對前面介紹的總方差計算函數的引用,而ops
是一個用戶定義的參數構成的元組,用以完成樹的構建。
下面的代碼中,函數chooseBestSplit
最複雜,該函數的目標是找到數據集切分的最佳位置。它遍歷所有的特徵及其可能的取值來找到使誤差最小化的切分閾值。該函數的偽代碼大致如下:
對每個特徵: 對每個特徵值: 將數據集切分成兩份 計算切分的誤差 如果當前誤差小於當前最小誤差,那麼將當前切分設定為最佳切分並更新最小誤差 返回最佳切分的特徵和閾值
下面給出上述三個函數的具體實現代碼。打開regTrees.py
文件並加入程序清單9-2中的代碼。
程序清單9-2 回歸樹的切分函數
def regLeaf(dataSet):
return mean(dataSet[:,-1])
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0]; tolN = ops[1]
#❶(以下兩行) 如果所有值相等則退出
if len(set(dataSet[:,-1].T.tolist[0])) == 1:
return None, leafType(dataSet)
m,n = shape(dataSet)
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
for splitVal in set(dataSet[:,featIndex]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#❷(以下兩行)如果誤差減少不大則退出
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
#❸(以下兩行)如果切分出的數據集很小則退出
return None, leafType(dataSet)
return bestIndex,bestValue
上述程序清單中的第一個函數是regLeaf
,它負責生成葉節點。當chooseBestSplit
函數確定不再對數據進行切分時,將調用該regLeaf
函數來得到葉節點的模型。在回歸樹中,該模型其實就是目標變量的均值。
第二個函數是誤差估計函數regErr
。該函數在給定數據上計算目標變量的平方誤差。當然也可以先計算出均值,然後計算每個差值再平方。但這裡直接調用均方差函數var
更加方便。因為這裡需要返回的是總方差,所以要用均方差乘以數據集中樣本的個數。
第三個函數是chooseBestSplit
,它是回歸樹構建的核心函數。該函數的目的是找到數據的最佳二元切分方式。如果找不到一個「好」的二元切分,該函數返回 None
並同時調用createTree
方法來產生葉節點,葉節點的值也將返回None
。接下來將會看到,在函數chooseBestSplit
中有三種情況不會切分,而是直接創建葉節點。如果找到了一個「好」的切分方式,則返回特徵編號和切分特徵值。
函數chooseBestSplit
一開始為ops
設定了tolS
和tolN
這兩個值。它們是用戶指定的參數,用於控制函數的停止時機。其中變量tolS
是容許的誤差下降值,tolN
是切分的最少樣本數。接下來通過對當前所有目標變量建立一個集合,函數chooseBestSplit
會統計不同剩餘特徵值的數目。如果該數目為1,那麼就不需要再切分而直接返回❶。然後函數計算了當前數據集的大小和誤差。該誤差S
將用於與新切分誤差進行對比,來檢查新切分能否降低誤差。下面很快就會看到這一點。
這樣,用於找到最佳切分的幾個變量就被建立和初始化了。下面就將在所有可能的特徵及其可能取值上遍歷,找到最佳的切分方式。最佳切分也就是使得切分後能達到最低誤差的切分。如果切分數據集後效果提升不夠大,那麼就不應進行切分操作而直接創建葉節點❷。另外還需要檢查兩個切分後的子集大小,如果某個子集的大小小於用戶定義的參數tolN
,那麼也不應切分。最後,如果這些提前終止條件都不滿足,那麼就返回切分特徵和特徵值❸。
9.3.2 運行代碼
下面在一些數據上看看上節代碼的實際效果,以圖9-1的數據為例,我們的目標是從該數據生成一棵回歸樹。
將程序清單9-2中的代碼添加到regTree.py
文件並保存,然後在Python提示符下輸入:
>>>reload(regTrees)
<module \'regTrees\' from \'regTrees.pyc\'>
>>> from numpy import *
圖9-1的數據存儲在文件ex00.txt
中。
>>> myDat=regTrees.loadDataSet(\'ex00.txt\')
>>> myMat = mat(myDat)
>>> regTrees.createTree(myMat)
{\'spInd\': 0, \'spVal\': matrix([[ 0.48813]]),
\'right\': -0.044650285714285733,
\'left\': 1.018096767241379}
圖9-1 基於CART算法構建回歸樹的簡單數據集
再看一個多次切分的例子,考慮圖9-2的數據集。
圖9-2 用於測試回歸樹的分段常數數據集
圖9-2的數據保存在一個以tab鍵分隔的文本文檔ex0.txt
中數據。為從上述數據中構建一棵回歸樹,在Python提示符下敲入如下命令:
>>> myDat1=regTrees.loadDataSet(\'ex0.txt\') >>> myMat1=mat(myDat1) >>> regTrees.createTree(myMat1) {\'spInd\': 1, \'spVal\': matrix([[ 0.39435]]), \'right\': {\'spInd\': 1, \'spVal\': matrix([[ 0.197834]]), \'right\': -0.023838155555555553, \'left\': 1.0289583666666664}, \'left\': {\'spInd\': 1, \'spVal\': matrix([[ 0.582002]]), \'right\': 1.9800350714285717, \'left\': {\'spInd\': 1, \'spVal\': matrix([[ 0.797583]]), \'right\': 2.9836209534883724, \'left\': 3.9871632000000004}}}
可以檢查一下該樹的結構以確保樹中包含5個葉節點。讀者也可以在更複雜的數據集上構建回歸樹並觀察實驗結果。
到現在為止,已經完成回歸樹的構建,但是需要某種措施來檢查構建過程否得當。下面將介紹樹剪枝(tree pruning)技術,它通過對決策樹剪枝來達到更好的預測效果。