機器學習初學者,超級小白,不對的地方盡請批評指正。歡迎一起探討。
K-nearest Neighbors 學習方法是基于實例的,可用于逼近實值或離散目標函數,概念簡明。對于基于實例的算法,學習過程只是簡單地存儲已知的訓練數據,當遇到新的查詢實例時,一系列相似的實例被從存儲器中取出,并用來分類新的查詢實例。因此,基于實例的算法的最大不足也就在于分類新實例的開銷特別大。
關于該算法的基本介紹可以參考下教材或是維基百科k-nearest neighbor algorithm。這里主要寫一下比較重要的問題。
對于K-nearest Neighbors算法而言,其距離是根據標準歐式距離定義的。可以把實例看做為一個多維向量,其距離就是求向量間的距離。
1NN:預測值或類別,僅根據訓練集中離待預測實例最近的參考實例決定
KNN:首先找到與待測實例最近的k個點,然后根據這k個點決定。進行分類:選擇這k個實例中最普遍的類別值(majority vote);進行回歸(求值):加權平均值(average weighted by inverse distance)。
基本過程:
1 Calculate distances of all training vectors to test vector
2 Pick k closest vectors
3 Calculate average/majority
雖然,KNN算法的原理很簡單,但是其中很多問題需要解決。比如k值如何選擇(k值過小,比較局限不穩定;k值過大,很多噪點影響),如何選擇維度(實例中可能有很多維度屬性與分類無關,而這些維度卻很大程度影響了距離的計算結果),如何規格化參數(比如一個實例向量<1,1000,5>,該向量的第二個屬性影響因子太大,因為我們一般認為所有屬性是同等重要的,因此需要規格化樣本數據),如何建立高效的索引(避免每次分類計算開銷過大)。。。其實,需要研究的問題很多,也很困難。
一個使用K-nearest Neighbors 算法進行分類的應用實例:
進行手寫數字的分類
數據集:Training dataset:http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tra
Test dataset:http://archive.ics.uci.edu/ml/machine-learning-databases/optdigits/optdigits.tes
說明:
這是一個分類的問題,手寫識別數字,類別為數字0~9。
每個訓練實例,本應該是一個32*32大小的0、1矩陣,但是由于維數過大,有人對此進行了優化,即按4*4大小將其分塊,然后化簡為8*8大小的矩陣,并用一個32維的向量進行表示。
代碼如下:
from numpy import * import operator def file2matrix (dataset_filename) : dataset = open(dataset_filename , 'r') items = dataset.readlines() dimension = len(items[0].split(',')) - 1 train_items_lines = len(items) returnMat = zeros((train_items_lines , dimension)) index = 0 classLabelVector = [] for item in items : item = item.strip() split_item_list = item.split(',') split_item_list = map(lambda x:int(x) , split_item_list) returnMat[index,:] = split_item_list[:dimension] classLabelVector.append(int(split_item_list[-1])) index += 1 dataset.close() return returnMat , classLabelVector def classify(inX , dataset , labels , k) : datasetSize = dataset.shape[0] #Compute distance using matrix #inX repeats datasetSize rows to be a matrix with the same size of dataset diffMat = tile(inX , (datasetSize,1)) - dataset sqDiffMat = diffMat**2 sqDistance = sqDiffMat.sum(axis=1) distance = sqDistance**0.5 sortedDistIndicies = distance.argsort() classCount = {} for i in range(k) : label = labels[sortedDistIndicies[i]] classCount[label] = classCount.get(label,0) + 1 sortedClassCount = sorted(classCount.iteritems() , key=operator.itemgetter(1) , reverse=True) return sortedClassCount[0][0] def test_classify ( k ) : dataset , labels = file2matrix('dataset/optdigits.tra') test_dataset = open('dataset/optdigits.tes' , 'r') test_items = test_dataset.readlines() success = 0 error = 0 for item in test_items : item = item.strip() split_item_list = item.split(',') split_item_list = map(lambda x:int(x) , split_item_list) classify_res = classify(split_item_list[:-1] , dataset , labels , k) real_res = split_item_list[-1] if classify_res == real_res : success += 1 else : error += 1 print '*'*10 , k ,'*'*10 print 'success\t' , success print 'error\t' , error return float(error)/float((error+success)) if __name__ == '__main__' : print test_classify(1) print test_classify(5) print test_classify(10) print test_classify(15) print test_classify(20) print test_classify(30)
運行結果:不同k值得影響
不難看出,k值并非越大越好,對于該問題而言,k在5的范圍之內似乎是最佳的。
![]() |
不含病毒。www.avast.com |