讀古今文學網 > 機器學習實戰 > 6.6 示例:手寫識別問題回顧 >

6.6 示例:手寫識別問題回顧

考慮這樣一個假想的場景。你的老闆過來對你說:「你寫的那個手寫體識別程序非常好,但是它佔用的內存太大了。顧客不能通過無線的方式下載我們的應用(在寫本書時,無線下載的限制容量為10MB,可以肯定,這將來會成為笑料的。)我們必須在保持其性能不變的同時,使用更少的內存。我呢,告訴了CEO,你會在一周內準備好,但你到底還得多長時間才能搞定這件事?」我不確定你到底會如何回答,但是如果想要滿足他們的需求,你可以考慮使用支持向量機。儘管第2章所使用的kNN方法效果不錯,但是需要保留所有的訓練樣本。而對於支持向量機而言,其需要保留的樣本少了很多(即只保留支持向量),但是能獲得可比的效果。

示例:基於SVM的數字識別

  1. 收集數據:提供的文本文件。
  2. 準備數據:基於二值圖像構造向量。
  3. 分析數據:對圖像向量進行目測。
  4. 訓練算法:採用兩種不同的核函數,並對徑向基核函數採用不同的設置來運行SMO算法 。
  5. 測試算法:編寫一個函數來測試不同的核函數並計算錯誤率。
  6. 使用算法:一個圖像識別的完整應用還需要一些圖像處理的知識,這裡並不打算深入介紹。

使用第2章中的一些代碼和SMO算法,可以構建一個系統去測試手寫數字上的分類器。打開svmMLiA.py並將第2章中knn.py中的img2vector函數複製過來。然後,加入程序清單6-9中的代碼。

程序清單6-9 基於SVM的手寫數字識別

def loadImages(dirName):
    from os import listdir
    hwLabels = 
    trainingFileList = listdir(dirName)
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split(\'.\')[0]
        classNumStr = int(fileStr.split(\'_\')[0])
        if classNumStr == 9: hwLabels.append(-1)
        else: hwLabels.append(1)
        trainingMat[i,:] = img2vector(\'%s/%s\' % (dirName, fileNameStr))
    return trainingMat, hwLabels
def testDigits(kTup=(\'rbf\', 10)):
    dataArr,labelArr = loadImages(\'trainingDigits\')
    b,alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, kTup)
    datMat=mat(dataArr); labelMat = mat(labelArr).transpose
    svInd=nonzero(alphas.A>0)[0]
    sVs=datMat[svInd]
    labelSV = labelMat[svInd];
    print \"there are %d Support Vectors\" % shape(sVs)[0]
    m,n = shape(datMat)
    errorCount = 0
    for i in range(m):
        kernelEval = kernelTrans(sVs,datMat[i,:],kTup)
        predict=kernelEval.T * multiply(labelSV,alphas[svInd]) + b
        if sign(predict)!=sign(labelArr[i]): errorCount += 1
    print \"the training error rate is: %f\" % (float(errorCount)/m)
    dataArr,labelArr = loadImages(\'testDigits\')
    errorCount = 0
    datMat=mat(dataArr); labelMat = mat(labelArr).transpose
    m,n = shape(datMat)
    for i in range(m):
        kernelEval = kernelTrans(sVs,datMat[i,:],kTup)
        predict=kernelEval.T * multiply(labelSV,alphas[svInd]) + b
        if sign(predict)!=sign(labelArr[i]): errorCount += 1
print \"the test error rate is: %f\" % (float(errorCount)/m)
  

函數loadImages是作為前面kNN.py中的handwritingClassTest的一部分出現的。它已經被重構為自身的一個函數。其中僅有的一個大區別在於,在kNN.py中代碼直接應用類別標籤,而同支持向量機一起使用時,類別標籤為-1或者+1。因此,一旦碰到數字9,則輸出類別標籤-1,否則輸出+1。本質上,支持向量機是一個二類分類器,其分類結果不是+1就是-1。基於SVM構建多類分類器已有很多研究和對比了,如果讀者感興趣,建議閱讀C. W. Huset等人發表的一篇論文「A Comparison of Methods for Multiclass Support Vector Machines」1。由於這裡我們只做二類分類,因此除了1和9之外的數字都被去掉了。

1. C. W. Hus, and C. J. Lin, 「A Comparison of Methods for Multiclass Support Vector Machines,」 IEEE Transactions on Neural Networks 13, no. 2 (March 2002), 415–25.

下一個函數testDigits並不是全新的函數,它和testRbf的代碼幾乎一樣,唯一的大區別就是它調用了loadImages函數來獲得類別標籤和數據。另一個細小的不同是現在這裡的函數元組kTup是輸入參數,而在testRbf中默認的就是使用rbf核函數。如果對於函數testDigits不增加任何輸入參數的話,那麼kTup的默認值就是(\'rbf\' ,10)。

輸入程序清單6-9中的代碼之後,將之保存為svmMLiA.py並輸入如下命令:

>>> svmMLiA.testDigits((\'rbf\', 20))
                    .
                    .
L==H
fullSet, iter: 3 i:401, pairs changed 0
iteration number: 4
there are 43 Support Vectors
the training error rate is: 0.017413
the test error rate is: 0.032258
 

我嘗試了不同的σ值,並嘗試了線性核函數,總結得到的結果如表6-1所示。

表6-1 不同σ值的手寫數字識別性能

內核,設置 訓練錯誤率(%)測試錯誤率(%)支持向量數 RBF, 0.1052402 RBF, 503.2402 RBF, 1000.599 RBF, 500.22.241 RBF, 100 4.5 4.3 26 Linear2.72.238

表6-1給出的結果表明,當徑向基核函數中的參數σ取10左右時,就可以得到最小的測試錯誤率。該參數值比前面例子中的取值大得多,而前面的測試錯誤率在1.3左右。為什麼差距如此之大?原因就在於數據的不同。在手寫識別的數據中,有1024個特徵,而這些特徵的值有可能高達1.0。而在6.5節的例子中,所有數據從-1到1變化,但是只有2個特徵。如何才能知道該怎麼設置呢?說老實話,在寫這個例子時我也不知道。我只是對不同的設置進行了多次嘗試。C的設置也會影響到分類的結果。當然,存在另外的SVM形式,它們把C同時考慮到了優化過程中,例如v-SVM。有關v-SVM的一個較好的討論可以參考本書第3章介紹過的Sergios Theodoridis和Konstantinos Koutroumbas撰寫的Pattern Recognition2。

2 .Sergios Theodoridis and Konstantinos Koutroumbas, Pattern Recognition, 4th ed. (Academic Press, 2009), 133.

你可能注意到了一個有趣的現象,即最小的訓練錯誤率並不對應於最小的支持向量數目。另一個值得注意的就是,線性核函數的效果並不是特別的糟糕。可以以犧牲線性核函數的錯誤率來換取分類速度的提高。儘管這一點在實際中是可以接受的,但是還得取決於具體的應用。