使用Python中的广义加性模型构建可解释模型
目的
构建能够产生高精度的机器/深度学习模型正变得越来越容易,但在可解释性方面,大多数模型还远远不够好。在许多情况下,您可能需要更多地强调对模型的理解,而不是准确性
广义加性模型(GAM)是一种功能强大而又简单的技术,它的代表性较差。很少有数据科学家知道它,或者将它应用到日常工作中,特别是在Python中。在本文中,您将了解如何在Python中构建通用加性模型,以及如何使用其部分依赖函数来检查每个特征的贡献。我们使用两个公共数据集来构建两个GAM模型:一个用于分类,另一个用于回归。pyGAM包用于训练GAM。
这篇文章简要解释了GAM背后的理论。
- GAM 101
- 如何使用pyGAM构建GAM
- 证明GAM在回归和分类中的可解释性
GAM 101
要了解如何构建GAM,我们需要了解GAM的结构及其中的一些重要概念。GAM的结构可以写成:
- g(E(Y))是链式函数,它将期望值与预测变量x1,x2,...,xm相关联。它告诉响应的预期值如何与预测变量相关。GAM支持多种链式函数。
- f1(x1) + f2(x2) +…+fm(xm)是一种加性结构的函数形式,它由许多项组成:f1(x1)、f2(x2)、…、fm(xm)。这些术语表示光滑的非参数函数。
- 分布是指响应变量Y的分布。它可以是指数族的任何分布,例如高斯,二项式泊松等。
GAM可对部分或全部的自变量采用平滑函数的方法建立模型,函数可以是非参数的形式,适用于多种分布类型、多种复杂非线性关系的分析。部分依赖图用于证明部分关系。
如何使用pyGAM构建GAM
pyGAM是一个用于在Python中构建GAM的包。据我所知,它可能是唯一可用于GAM的Python包。pyGAM在pypi上,可以使用pip安装
pip install pygam
要使用pyGAM训练GAM,我们需要指定链式函数,函数形式和分布如下:
from pygam import GAM, s, f gam = GAM(s(0, n_splines=5) + s(1) + f(2) + s(3), distribution=’gamma’, link=’log’)
pyGAM还具有内置的通用模型,可以轻松创建GAM。常见的模型有LinearGAM,LogisticGAM,PoissonGAM,GammaGAM,InvGuss。模型训练简化为:
from pygam import PoissonGAM gam = PoissonGAM(s(0, n_splines=5) + s(1) + f(2) + s(3))
使用`gridsearch()`自动调整模型
找到最佳模型需要调整几个关键参数,包括n_splines,lam和constraints。其中,lam对GAM的性能非常重要。它控制着每个项的正则化惩罚的强度。pyGAM构建了一个网格搜索功能,可以构建一个网格来搜索多个lam值,以便具有最低广义交叉验证(GCV)得分的模型。
部分依赖图
pyGAM支持matplotlib的部分依赖图。GAM中每个项的部分依赖性可以用估计函数的95%置信区间来可视化。
构建可解释的GAM
回归
这个机器学习数据集是关于葡萄牙“Vinho Verde”葡萄酒的红色变种,可从UCI机器学习库获得。输入特征是11个物理化学变量,描述了来自各个方面的红葡萄酒变体。目标特征是质量分数,范围从0到10,这表明红酒有多好。
数据集概述
准备数据
redwine_X = redwine.drop(['quality'], axis=1).values redwine_y = redwine['quality']
通过gridsearch构建模型
lams = np.random.rand(100, 11) lams = lams * 11 - 3 lams = np.exp(lams) print(lams.shape) gam = LinearGAM(n_splines=10).gridsearch(redwine_X, redwine_y, lam=lams)
部分依赖图
titles = redwine.columns[0:11] plt.figure() fig, axs = plt.subplots(1,11,figsize=(40, 8)) for i, ax in enumerate(axs): XX = gam.generate_X_grid(term=i) ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX)) ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX, width=.95)[1], c='r', ls='--') if i == 0: ax.set_ylim(-30,30) ax.set_title(titles[i])
部分依赖图显示影响红葡萄酒质量的特征
到目前为止,我们已经建立了一个线性GAM,可以根据物理化学变量预测红葡萄酒质量得分。更重要的是,在上述部分依赖图中揭示了这些物理化学变量中的每一个如何影响质量得分。如上所示,volatile acidity, chlorides, total sulfur dioxide, density,和pH与质量得分呈负相关,意味着值越高,质量得分越低。另一方面,随着residual sugar和free sulfur dioxide的值变大,质量得分增加。我们还注意到固定酸度对质量评分有影响。柠檬酸,硫酸盐,酒精的影响更为复杂。例如,最佳酒精含量约为13.值高于或低于该值会降低质量得分。
分类
该机器学习数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/)包含30个特征,描述乳腺肿块图像中细胞核的特征。它们由细胞核的10个描述符计算得到,包括半径、纹理、周长等。此外,每个记录都被标记为malignant (M)或bengn (B)。
机器学习数据集概述
准备数据
# drop the id column tumors = tumors.drop(['id'],axis=1) # encode the diagnosis column tumors.loc[tumors['diagnosis']=='M','diagnosis'] =1 tumors.loc[tumors['diagnosis']=='B','diagnosis'] =0 tumors_X = tumors.iloc[:,:11].drop(['diagnosis'], axis=1).values tumors_y = tumors['diagnosis']
使用LogisticGAM构建模型
log_gam = LogisticGAM(n_splines=10).gridsearch(tumors_X, tumors_y)
检查训练模型的准确性
log_gam.accuracy(tumors_X, tumors_y) 0.9578207381370826
部分依赖图
titles = tumors.columns[1:11] plt.figure() fig, axs = plt.subplots(1,10,figsize=(40, 8)) for i, ax in enumerate(axs): XX = log_gam.generate_X_grid(term=i) ax.plot(XX[:, i], log_gam.partial_dependence(term=i, X=XX)) ax.plot(XX[:, i], log_gam.partial_dependence(term=i, X=XX, width=.95)[1], c='r', ls='--') if i == 0: ax.set_ylim(-30,30) ax.set_title(titles[i])
乳房肿块图像分类的部分依赖图
部分依赖图揭示了GAM模型的可解释性。与响应变量正相关的变量包括:平均半径,平均纹理,平均面积,平均平滑度,平均凹点和平均对称性。值越高,恶性越大。平均周长越大意味着恶化的可能性越小。
结论
作为数据科学家,您应该将GAM添加到您的库中。它在可解释性方面的优势在许多情况下都非常有用。希望本文能帮助您了解技术并在您的工作中尝试。