讀古今文學網 > 機器學習實戰 > 9.5 模型樹 >

9.5 模型樹

用樹來對數據建模,除了把葉節點簡單地設定為常數值之外,還有一種方法是把葉節點設定為分段線性函數,這裡所謂的分段線性(piecewise linear)是指模型由多個線性片段組成。如果讀者仍不清楚,下面很快就會給出樣例來幫助理解。考慮圖9-4中的數據,如果使用兩條直線擬合是否比使用一組常數來建模好呢 ?答案顯而易見。可以設計兩條分別從0.0~0.3、從0.3~1.0的直線,於是就可以得到兩個線性模型。因為數據集裡的一部分數據(0.0~0.3)以某個線性模型建模,而另一部分數據(0.3~1.0)則以另一個線性模型建模,因此我們說採用了所謂的分段線性模型。

決策樹相比於其他機器學習算法的優勢之一在於結果更易理解。很顯然,兩條直線比很多節點組成一棵大樹更容易解釋。模型樹的可解釋性是它優於回歸樹的特點之一。另外,模型樹也具有更高的預測準確度。

圖9-4 用來測試模型樹構建函數的分段線性數據

前面的代碼稍加修改就可以在葉節點生成線性模型而不是常數值。下面將利用樹生成算法對數據進行切分,且每份切分數據都能很容易被線性模型所表示。該算法的關鍵在於誤差的計算。

前面已經給出了樹構建的代碼,但是這裡仍然需要給出每次切分時用於誤差計算的代碼。不知道讀者是否還記得之前createTree函數里有兩個參數從未改變過。回歸樹把這兩個參數固定,而此處略做修改,從而將前面的代碼重用於模型樹。

下一個問題就是,為了找到最佳切分,應該怎樣計算誤差呢?前面用於回歸樹的誤差計算方法這裡不能再用。稍加變化,對於給定的數據集,應該先用線性的模型來對它進行擬合,然後計算真實的目標值與模型預測值間的差值。最後將這些差值的平方求和就得到了所需的誤差。為瞭解實際效果,打開regTrees.py文件並加入如下代碼。

程序清單9-4 模型樹的葉節點生成函數

def linearSolve(dataSet):
    m,n = shape(dataSet)
    # ❶(以下兩行)將X與Y中的數據格式化
    X = mat(ones((m,n))); Y = mat(ones((m,1)))
    X[:,1:n] = dataSet[:,0:n-1]; Y = dataSet[:,-1]
    xTx = X.T*X
    if linalg.det(xTx) == 0.0:
        raise NameError(\'This matrix is singular, cannot do inverse,n
        try increasing the second value of ops\')
    ws = xTx.I * (X.T * Y)
    return ws,X,Y

def modelLeaf(dataSet):
    ws,X,Y = linearSolve(dataSet)
    return ws

def modelErr(dataSet):
    ws,X,Y = linearSolve(dataSet)
    yHat = X * ws
    return sum(power(Y - yHat, 2))
  

上述程序清單中的第一個函數是linearSolve,它會被其他兩個函數調用。其主要功能是將數據集格式化成目標變量Y和自變量X ❶。與第8章一樣,XY用於執行簡單的線性回歸。另外在這個函數中也應當注意,如果矩陣的逆不存在也會造成程序異常。

第二個函數modelLeaf與程序清單9-2里的函數regLeaf類似,當數據不再需要切分的時候它負責生成葉節點的模型。該函數在數據集上調用linearSolve並返回回歸係數ws

最後一個函數是modelErr,可以在給定的數據集上計算誤差。它與程序清單9-2的函數regErr類似,會被chooseBestSplit調用來找到最佳的切分。該函數在數據集上調用linearSolve,之後返回yHatY之間的平方誤差。

至此,使用程序清單9-1和9-2中的函數構建模型樹的全部代碼已經完成。為瞭解實際效果,保存regTrees.py文件並在Python提示符下輸入:

>>> reload(regTrees)
<module \'regTrees\' from \'regTrees.pyc\'>  
  

圖9-4的數據已保存在一個用tab鍵為分隔符的文本文件exp2.txt裡。

>>> myMat2 = mat(regTrees.loadDataSet(\'exp2.txt\'))
  

為了調用函數createTree和模型樹的函數,需將模型樹函數作為createTree的參數,輸入下面的命令:

>>> regTrees.createTree(myMat2, regTrees.modelLeaf, regTrees.modelErr,(1,10))
{\'spInd\': 0, \'spVal\': matrix([[ 0.285477]]), \'right\': matrix([[3.46877936], [ 1.18521743]]), \'left\': matrix([[ 1.69855694e-03],
[ 1.19647739e+01]])}  
  

可以看到,該代碼以0.285 477為界創建了兩個模型,而圖9-4的數據實際在0.3處分段。createTree生成的這兩個線性模型分別是 y = 3.468 + 1.1852y = 0.001 698 5 + 11.964 77x,與用於生成該數據的真實模型非常接近。該數據實際是由模型 y = 3.5 + 1.0xy = 0 + 12x 再加上高斯噪聲生成的。在圖9-5上可以看到圖9-4的數據以及生成的線性模型。

圖9-5 在圖9-4數據集上應用模型樹算法得到的結果

模型樹、回歸樹以及第8章裡的其他模型,哪一種模型更好呢?一個比較客觀的方法是計算相關係數,也稱為R2值。該相關係數可以通過調用NumPy庫中的命令corrcoef(yHat, y, rowvar=0)來求解,其中yHat是預測值,y是目標變量的實際值。

前一章使用了標準的線性回歸法,本章則使用了樹回歸法,下面將通過實例對二者進行比較,最後用函數corrcoef來分析哪個模型是最優的。