Python中的K-Nearest Neighbor分类
K-nearest neighbor 是用于对数据进行分类的监督机器学习算法之一。
nearest neighbor 方法背后的原理是找到已经确定的多个训练数据并计算出离新点最近的距离,并预测新数据的标签。样本的数量可以是用户确定的常数(k-nearest neighbor learning)。
KNN与python
可以使用python sklearn-nearest neighbor库运行KNN算法。
导入库
import pandas as pd from sklearn import preprocessing
数据集:
本实验中使用的数据使用caravan.csv数据。数据包含有关保险公司客户的信息。该数据由86个变量组成,包括变量使用数据和社会人口数据。根据这些数据,我们将预测是否有人会购买保险。
#Import data df=pd.read_csv('D:/dataset/Caravan.csv') df.head()
划分数据X属性和y(标签)
数据X是非标签数据,而数据y是标签数据。除此之外,还使用MinMaxScaler()完成数据规范化。
x=df.drop("Purchase",axis=1) y=df["Purchase"] #data Dinormalisasi min_max_scaler = preprocessing.MinMaxScaler() #Normalisasi data X x_scaled = min_max_scaler.fit_transform(x) #dibuat data frame df_xscaled = pd.DataFrame(x_scaled)
划分训练数据和测试数据
在此阶段,数据共享用于训练和测试。比例为80:20。因此,在训练数据中有80%的数据和测试有20%的数据,这是针对标准化数据。在此阶段使用train_test_split函数。
from sklearn.model_selection import train_test_split #Untuk data tidak di normalisasi x_train, x_test = train_test_split(x, test_size=0.2) y_train, y_test = train_test_split(y, test_size=0.2) #Untuk data ternormalisasi cukup x saja. x_scaled_train, x_scaled_test = train_test_split(df_xscaled, test_size=0.2)
用KNN建模
建模分为两部分,第一部分建模用于非标准化数据。在该建模中,执行10次迭代以查看10k-nearest neighbor可能性的准确性。使用的库是nearest_neighbor。精度计算使用sklearn.metrics.accuracy_score。
建模代码1:
from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score array_hasil=[] for i in range(1,11): knn = KNeighborsClassifier(n_neighbors=i) knn.fit(x_train,y_train) #masukan prediksi pred=knn.predict(x_test) #nilaiPrediksi hasil=accuracy_score(y_test, pred) array_hasil.append(hasil)
建模代码2:
array_norm=[] for i in range(1,11): knn = KNeighborsClassifier(n_neighbors=i) knn.fit(x_scaled_train,y_train) #masukan prediksi pred1=knn.predict(x_scaled_test) #nilaiPrediksi hasil1=accuracy_score(y_test, pred1) array_norm.append(hasil1)
准确性和可视化
前10个建模迭代的准确性:
import matplotlib.pyplot as plt import numpy as np plt.plot(array_hasil) plt.ylabel('nilai akurasi') plt.xlabel('nilai K') plt.xticks(np.arange(10),('1','2','3','4','5','6','7','8','9','10')) plt.show()
建模结果1:
第10次建模迭代的准确性:
import matplotlib.pyplot as plt plt.plot(array_norm) plt.ylabel('nilai akurasi') plt.xlabel('nilai K') plt.xticks(np.arange(10),('1','2','3','4','5','6','7','8','9','10')) plt.show()
结论
从两个模型可以看出,精度值在k = 4时达到稳定点。具有标准化数据的模型比没有标准化的数据更好的准确度值。