机器学习:在PyTorch中实现Grad-CAM
Deep Learning With Python一书中描述了VGG16网络的类激活映射的实现,使用Keras实现了算法。本文将使用PyTorch重新实现CAM算法。
Grad-CAM
该算法本身来源于这个论文。这是对计算机视觉分析工具的一个很好的补充。它为我们提供了一种方法来研究图像的哪些特定部分影响(influenced)了整个模型对特定指定标签的决策。它在分析错误分类的样本时特别有用。该算法直观、易于实现。
算法背后的直觉是基于这样一个事实,即深度学习模型必须已经看到了一些像素(或图像的区域),并决定图像中出现了什么对象。 可以用梯度来描述数学术语中的影响(Influence)。 在高层次上,这就是算法的作用。 它首先找到相对于深度模型中最新激活映射的最主要logit的梯度。 我们可以将其解释为最终激活映射中最终激活的一些编码特征说服整个机器学习模型选择特定的logit(随后是相应的类)。 然后通过channel-wise池化梯度,并且用相应的梯度对激活通道(channels)进行加权,产生加权激活通道的集合。 通过检查这些通道,我们可以判断哪些通道在类决策中发挥了最重要的作用。
在这篇文章中,我将使用PyTorch重新实现Grad-CAM算法,我将使用不同的架构。
VGG19
在这部分中,我将尝试使用一个非常相似的深度学习模型 - VGG19重现Chollet的结果。我实现的主要思想是解剖网络,这样我们就可以得到最后一层卷积层的激活。Keras通过Keras函数有一种非常直接的方法。然而,在PyTorch我必须跳过一些次要的步骤。
该策略的定义如下:
- 加载VGG19模型
- 找到它的最后一个卷积层
- 计算最可能的类
- 对我们刚刚得到的激活映射使用logit类的梯度
- 池化梯度
- 通过相应的池化梯度对映射的通道进行加权
- 插值热图
我从ImageNet数据集中提取了一些图像(包括Chollet在书(Deep Learning With Python)中使用的大象图像)来研究这个算法。我还将Grad-CAM应用于Facebook上的一些照片,以了解该算法在“field”条件下的工作原理。以下是我们将要使用的原始图像:
左图:Chollet在他的书中使用的大象形象。中图和右图:来自ImageNet的白鲨图像
来自ImageNet数据集的鬣蜥图像
左:应用YOLO模型。右图:在莫斯科乘坐火车
让我们从torchvision模块加载VGG19模型并准备变换和数据加载器,Python代码如下:
import torch import torch.nn as nn from torch.utils import data from torchvision.models import vgg19 from torchvision import transforms from torchvision import datasets import matplotlib.pyplot as plt import numpy as np # use the ImageNet transformation transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) # define a 1 image dataset dataset = datasets.ImageFolder(root='./data/Elephant/', transform=transform) # define the dataloader to load that single image dataloader = data.DataLoader(dataset=dataset, shuffle=False, batch_size=1)
在这里,我导入了我们用于在PyTorch中使用神经网络的所有标准内容。我使用基本变换来使用在ImageNet数据集上训练的机器学习模型,包括图像归一化。我将一次提供一个图像,因此我将我的数据集定义为大象的图像,以获得与书中类似的结果。
这里有一个棘手的部分(但不是太棘手)。 我们可以使用在torch.Tensor上调用的.backward()方法计算PyTorch中的梯度。 我将在最可能的logit上调用backward(),这是通过网络执行图像的forward pass获得的。 但是,PyTorch仅缓存计算图中叶节点的梯度,例如权重,偏差和其他参数。 与激活相关的输出梯度仅仅是中间值,一旦梯度在返回时通过它们传播,就会被丢弃。 那么我们有什么选择呢?
PyTorch中有一个回调函数:hooks。Hooks可以用在不同的场景中。PyTorch文档告诉我们如何将钩子附加到中间值上,以便在丢弃梯度之前将它们从模型中拉出来。文件告诉我们:
每次计算相对于张量的梯度时,都会调用钩子。
现在我们知道我们必须将反向钩子注册到VGG19模型中最后一个卷积层的激活映射。
通过调用VGG19pretrained=True),我们可以很容易地观察VGG19架构:
PyTorch中的预训练机器学习模型大量使用了sequence()模块,在大多数情况下,这使得它们很难被分解,稍后我们将看到它的示例。
在图中,我们看到了整个VGG19架构。我突出显示了feature块中的最后一个卷积层(包括激活函数)。现在我们知道我们想要在网络的特征块的第35层注册backward hook。另外,值得一提的是,有必要将钩子注册到forward()方法中,以避免将钩子注册到duplicate tensor并随后丢失梯度的问题。
正如您所看到的,在feature块中还剩下一个最大池化层,不用担心,我将在forward()方法中添加这个层。Python代码如下:
class VGG(nn.Module): def __init__(self): super(VGG, self).__init__() # get the pretrained VGG19 network self.vgg = vgg19(pretrained=True) # disect the network to access its last convolutional layer self.features_conv = self.vgg.features[:36] # get the max pool of the features stem self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # get the classifier of the vgg19 self.classifier = self.vgg.classifier # placeholder for the gradients self.gradients = None # hook for the gradients of the activations def activations_hook(self, grad): self.gradients = grad def forward(self, x): x = self.features_conv(x) # register the hook h = x.register_hook(self.activations_hook) # apply the remaining pooling x = self.max_pool(x) x = x.view((1, -1)) x = self.classifier(x) return x # method for the gradient extraction def get_activations_gradient(self): return self.gradients # method for the activation exctraction def get_activations(self, x): return self.features_conv(x)
到目前为止,这看起来很棒,我们终于可以从机器学习模型中获得梯度和激活了。
Drawing CAM
首先,让我们用大象的图像在网络中pass through,看看VGG19预测了什么。不要忘记将你的深度学习模型设置为评估模式,否则你会得到非常随机的结果,Python实现如下:
# initialize the VGG model vgg = VGG() # set the evaluation mode vgg.eval() # get the image from the dataloader img, _ = next(iter(dataloader)) # get the most likely prediction of the model pred = vgg(img).argmax(dim=1)
正如预期的那样,我们得到的结果与Chollet在他的书中得到的结果相同:
Predicted: [('n02504458', 'African_elephant', 20.891441), ('n01871265', 'tusker', 18.035757), ('n02504013', 'Indian_elephant', 15.153353)]
现在,我们将使用第386类的logit进行反向传播,该logit代表ImageNet数据集中的“African_elephant”。
# get the gradient of the output with respect to the parameters of the model pred[:, 386].backward() # pull the gradients out of the model gradients = vgg.get_activations_gradient() # pool the gradients across the channels pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]) # get the activations of the last convolutional layer activations = vgg.get_activations(img).detach() # weight the channels by corresponding gradients for i in range(512): activations[:, i, :, :] *= pooled_gradients[i] # average the channels of the activations heatmap = torch.mean(activations, dim=1).squeeze() # relu on top of the heatmap # expression (2) in https://arxiv.org/pdf/1610.02391.pdf heatmap = np.maximum(heatmap, 0) # normalize the heatmap heatmap /= torch.max(heatmap) # draw the heatmap plt.matshow(heatmap.squeeze())
大象图像的热图
最后,我们获得了大象图像的热图。它是一个14x14单通道图像。大小由网络的最后卷积层中的激活映射的空间维度决定。
现在,我们可以使用OpenCV来插入热图并将其投影到原始图像上,这里我使用了Chollet书中的Python代码:
import cv2 img = cv2.imread('./data/Elephant/data/05fig34.jpg') heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0])) heatmap = np.uint8(255 * heatmap) heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) superimposed_img = heatmap * 0.4 + img cv2.imwrite('./map.jpg', superimposed_img)
在下面的图像中,我们可以看到我们的VGG19网络在决定分配给图像的哪个类('African_elephant')时最重视的图像区域。我们可以假设网络采用大象的头部和耳朵的形状作为图像中存在大象的强烈信号。更有趣的是,该网络还区分了非洲象和塔斯克象和印度象。我不是大象专家,但我认为耳朵和象牙的形状是非常好的区分标准。一般来说,这正是人类如何处理这样的任务。专家将检查耳朵和牙齿的形状,也许还有一些其他微妙的特征可以揭示它是什么类型的大象。
Grad-CAM热图投射到原始大象图像上
好的,让我们用其他一些图像重复相同的过程。
Left: the shark image with projected CAM heat-map
Another shark image with the corresponding CAM hea
鲨鱼主要通过上图图像中的嘴/牙齿区域以及下部图像中的体形和周围水来识别。太酷了!
超越VGG
VGG是一个伟大的架构,因为它提出了更新更高效的图像分类架构。在这一部分中,我们将研究其中一种架构:DenseNet。
在尝试为密集连接网络实施Grad-CAM时遇到了一些问题。首先,正如我已经提到的,PyTorch model zoo中的预训练机器学习模型大多是使用嵌套块构建的。它是可读性和效率的绝佳选择。请注意,VGG由2个块组成:feature block和full connected classifier。DenseNet由多个嵌套块组成,并且试图到达最后一个卷积层的激活映射是不切实际的。我们可以通过两种方式解决此问题:我们可以使用相应的批归一化层获取最后一个激活映射。这将产生非常好的结果,我们很快就会看到。我们要做的第二件事是从头开始构建DenseNet并重新填充块/层的权重,这样我们就可以直接访问这些层。第二种方法似乎太复杂和耗时,所以我避免使用它。
DenseNet CAM的Python代码几乎与用于VGG网络的代码相同,唯一的区别在于层的索引(在DenseNet的情况下为块)我们将从以下方式获得激活:
class DenseNet(nn.Module): def __init__(self): super(DenseNet, self).__init__() # get the pretrained DenseNet201 network self.densenet = densenet201(pretrained=True) # disect the network to access its last convolutional layer self.features_conv = self.densenet.features # add the average global pool self.global_avg_pool = nn.AvgPool2d(kernel_size=7, stride=1) # get the classifier of the vgg19 self.classifier = self.densenet.classifier # placeholder for the gradients self.gradients = None # hook for the gradients of the activations def activations_hook(self, grad): self.gradients = grad def forward(self, x): x = self.features_conv(x) # register the hook h = x.register_hook(self.activations_hook) # don't forget the pooling x = self.global_avg_pool(x) x = x.view((1, 1920)) x = self.classifier(x) return x def get_activations_gradient(self): return self.gradients def get_activations(self, x): return self.features_conv(x)
遵循DenseNet的架构设计非常重要,因此我在分类器之前将全局平均池化添加到网络中。
我将通过密集连接网络传递两个鬣蜥图像,以便找到分配给图像的类:
Predicted: [('n01698640', 'American_alligator', 14.080595), ('n03000684', 'chain_saw', 13.87465), ('n01440764', 'tench', 13.023708)]
在这里,网络预测这是“美国短吻鳄”的形象。嗯,让我们运行我们的Grad-CAM算法来对抗'American Alligator'类。在下面的图像中,我显示了热图和热图在图像上的投影。我们可以看到,网络主要是关注“生物”。很明显鳄鱼看起来像鬣蜥,因为它们都有共同的体形和整体结构。
常见的鬣蜥被误分类为美洲短吻鳄
但是,请注意图像的另一部分影响了类的分数。照片中的摄影师可能会用他的位置和姿势把网络搞乱。模型在做出选择时,兼顾了鬣蜥和人。让我们看看如果我们把拍照者从图像中裁剪出来会发生什么。以下是对裁剪后的图像的前3类预测:
Predicted: [('n01677366', 'common_iguana', 13.84251), ('n01644900', 'tailed_frog', 11.90448), ('n01675722', 'banded_gecko', 10.639269)]
裁切后的鬣蜥图像现在被分类为常见的鬣蜥
我们现在看到,从图像中裁剪人实际上有助于获得图像的正确类别标签。这是Grad-CAM的最佳应用之一:能够获得错误分类图像中可能出错的信息。一旦我们弄清楚可能发生了什么,我们就可以有效地调试机器学习模型。
第二只鬣蜥被正确分类,这里是相应的热图和投影。
第二只鬣蜥通过其背部的尖刺图案来识别
超越ImageNet
让我们尝试一下我从Facebook页面下载的一些图像。我将使用我们的DenseNet201来实现此目的。
抱着猫的形象分类如下:
Predicted: [('n02104365', 'schipperke', 12.584991), ('n02445715', 'skunk', 9.826308), ('n02093256', 'Staffordshire_bullterrier', 8.28862)]
我们来看看这个图像的类激活映射。
在下面的图像中,我们可以看到模型正在寻找正确的位置。
YOLO applied
让我们看看把人去掉是否有助于分类。
Predicted: [('n02123597', 'Siamese_cat', 6.8055286), ('n02124075', 'Egyptian_cat', 6.7294292), ('n07836838', 'chocolate_sauce', 6.4594917)]
现在至少被预测为猫,它更贴近真实的标签。
我们要看的最后一张照片。
图像分类正确:
Predicted: [('n02917067', 'bullet_train', 10.605988), ('n04037443', 'racer', 9.134802), ('n04228054', 'ski', 9.074459)]
我们确实在一辆火车前面。让我们看一下类激活映射。
需要注意的是,DenseNet的最后一层卷积层生成了7x7的空间激活映射(与VGG网络中的14x14相比),因此当将热图投影回图像空间时,热图的分辨率可能有些夸张(对应于脸上的红色)。
另一个可能出现的问题是我们为什么不直接计算logit类关于输入图像的梯度呢。请记住,卷积神经网络作为特征提取器工作,网络的更深层在越来越抽象的空间中运行。我们想知道哪些特征实际影响了模型对类的选择,而不仅仅是单个图像像素。这就是为什么对更深层的卷积层进行激活映射是至关重要的。