讀古今文學網 > 機器學習實戰 > 9.3 將CART算法用於回歸 >

9.3 將CART算法用於回歸

要對數據的複雜關係建模,我們已經決定借用樹結構來幫助切分數據,那麼如何實現數據的切分呢?怎麼才能知道是否已經充分切分呢?這些問題的答案取決於葉節點的建模方式。回歸樹假設葉節點是常數值,這種策略認為數據中的複雜關係可以用樹結構來概括。

為成功構建以分段常數為葉節點的樹,需要度量出數據的一致性。第3章使用樹進行分類,會在給定節點時計算數據的混亂度。那麼如何計算連續型數值的混亂度呢?事實上,在數據集上計算混亂度是非常簡單的。首先計算所有數據的均值,然後計算每條數據的值到均值的差值。為了對正負差值同等看待,一般使用絕對值或平方值來代替上述差值。上述做法有點類似於前面介紹過的統計學中常用的方差計算。唯一的不同就是,方差是平方誤差的均值(均方差),而這裡需要的是平方誤差的總值(總方差)。總方差可以通過均方差乘以數據集中樣本點的個數來得到。

有了上述誤差計算準則和上一節中的樹構建算法,下面就可以開始構建數據集上的回歸樹了。

9.3.1 構建樹

構建回歸樹,需要補充一些新的代碼,使程序清單9-1中的函數createTree得以運轉。首先要做的就是實現chooseBestSplit函數。給定某個誤差計算方法,該函數會找到數據集上最佳的二元切分方式。另外,該函數還要確定什麼時候停止切分,一旦停止切分會生成一個葉節點。因此,函數chooseBestSplit只需完成兩件事:用最佳方式切分數據集和生成相應的葉節點。

從程序清單9-1可以看出,除了數據集以外,函數chooseBestSplit還有leafTypeerrTypeops這三個參數。其中leafType是對創建葉節點的函數的引用,errType是對前面介紹的總方差計算函數的引用,而ops是一個用戶定義的參數構成的元組,用以完成樹的構建。

下面的代碼中,函數chooseBestSplit最複雜,該函數的目標是找到數據集切分的最佳位置。它遍歷所有的特徵及其可能的取值來找到使誤差最小化的切分閾值。該函數的偽代碼大致如下:

對每個特徵:
    對每個特徵值:
        將數據集切分成兩份
        計算切分的誤差
        如果當前誤差小於當前最小誤差,那麼將當前切分設定為最佳切分並更新最小誤差
返回最佳切分的特徵和閾值
  

下面給出上述三個函數的具體實現代碼。打開regTrees.py文件並加入程序清單9-2中的代碼。

程序清單9-2 回歸樹的切分函數

def regLeaf(dataSet):
    return mean(dataSet[:,-1])

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]
    #❶(以下兩行) 如果所有值相等則退出
    if len(set(dataSet[:,-1].T.tolist[0])) == 1:
        return None, leafType(dataSet)
    m,n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        for splitVal in set(dataSet[:,featIndex]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #❷(以下兩行)如果誤差減少不大則退出
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        #❸(以下兩行)如果切分出的數據集很小則退出  
        return None, leafType(dataSet)
    return bestIndex,bestValue
  

上述程序清單中的第一個函數是regLeaf,它負責生成葉節點。當chooseBestSplit函數確定不再對數據進行切分時,將調用該regLeaf函數來得到葉節點的模型。在回歸樹中,該模型其實就是目標變量的均值。

第二個函數是誤差估計函數regErr。該函數在給定數據上計算目標變量的平方誤差。當然也可以先計算出均值,然後計算每個差值再平方。但這裡直接調用均方差函數var更加方便。因為這裡需要返回的是總方差,所以要用均方差乘以數據集中樣本的個數。

第三個函數是chooseBestSplit,它是回歸樹構建的核心函數。該函數的目的是找到數據的最佳二元切分方式。如果找不到一個「好」的二元切分,該函數返回 None並同時調用createTree方法來產生葉節點,葉節點的值也將返回None。接下來將會看到,在函數chooseBestSplit中有三種情況不會切分,而是直接創建葉節點。如果找到了一個「好」的切分方式,則返回特徵編號和切分特徵值。

函數chooseBestSplit一開始為ops設定了tolStolN這兩個值。它們是用戶指定的參數,用於控制函數的停止時機。其中變量tolS是容許的誤差下降值,tolN是切分的最少樣本數。接下來通過對當前所有目標變量建立一個集合,函數chooseBestSplit會統計不同剩餘特徵值的數目。如果該數目為1,那麼就不需要再切分而直接返回❶。然後函數計算了當前數據集的大小和誤差。該誤差S將用於與新切分誤差進行對比,來檢查新切分能否降低誤差。下面很快就會看到這一點。

這樣,用於找到最佳切分的幾個變量就被建立和初始化了。下面就將在所有可能的特徵及其可能取值上遍歷,找到最佳的切分方式。最佳切分也就是使得切分後能達到最低誤差的切分。如果切分數據集後效果提升不夠大,那麼就不應進行切分操作而直接創建葉節點❷。另外還需要檢查兩個切分後的子集大小,如果某個子集的大小小於用戶定義的參數tolN,那麼也不應切分。最後,如果這些提前終止條件都不滿足,那麼就返回切分特徵和特徵值❸。

9.3.2 運行代碼

下面在一些數據上看看上節代碼的實際效果,以圖9-1的數據為例,我們的目標是從該數據生成一棵回歸樹。

將程序清單9-2中的代碼添加到regTree.py文件並保存,然後在Python提示符下輸入:

>>>reload(regTrees)
<module \'regTrees\' from \'regTrees.pyc\'>
>>> from numpy import *
  

圖9-1的數據存儲在文件ex00.txt中。

>>> myDat=regTrees.loadDataSet(\'ex00.txt\')
>>> myMat = mat(myDat)
>>> regTrees.createTree(myMat)
{\'spInd\': 0, \'spVal\': matrix([[ 0.48813]]),
\'right\': -0.044650285714285733,
\'left\': 1.018096767241379}
  

圖9-1 基於CART算法構建回歸樹的簡單數據集

再看一個多次切分的例子,考慮圖9-2的數據集。

圖9-2 用於測試回歸樹的分段常數數據集

圖9-2的數據保存在一個以tab鍵分隔的文本文檔ex0.txt中數據。為從上述數據中構建一棵回歸樹,在Python提示符下敲入如下命令:

>>> myDat1=regTrees.loadDataSet(\'ex0.txt\')
>>> myMat1=mat(myDat1)
>>> regTrees.createTree(myMat1)
{\'spInd\': 1, \'spVal\': matrix([[ 0.39435]]), \'right\': {\'spInd\': 1, \'spVal\':
matrix([[ 0.197834]]), \'right\': -0.023838155555555553, \'left\':
1.0289583666666664}, \'left\': {\'spInd\': 1, \'spVal\': matrix([[ 0.582002]]),
\'right\': 1.9800350714285717, \'left\': {\'spInd\': 1, \'spVal\': matrix([[
0.797583]]), \'right\': 2.9836209534883724, \'left\': 3.9871632000000004}}}
  

可以檢查一下該樹的結構以確保樹中包含5個葉節點。讀者也可以在更複雜的數據集上構建回歸樹並觀察實驗結果。

到現在為止,已經完成回歸樹的構建,但是需要某種措施來檢查構建過程否得當。下面將介紹樹剪枝(tree pruning)技術,它通過對決策樹剪枝來達到更好的預測效果。