UNet解释及Python实现

介绍

在图像分割中,机器必须将图像分割成不同的segments,每个segment代表不同的实体。

UNet解释及Python实现

图像分割示例

正如你在上面看到的,图像如何变成两个部分,一个代表猫,另一个代表背景。图像分割在从自动驾驶汽车到卫星的许多领域都很有用。也许其中最重要的是医学影像。医学图像的微妙之处是相当复杂的。一台能够理解这些细微差别并识别出必要区域的机器,可以对医疗保健产生深远的影响。

卷积神经网络在简单的图像分割问题上取得了不错的效果,但在复杂的图像分割问题上却没有取得任何进展。这就是UNet的作用。UNet最初是专门为医学图像分割而设计的。该方法取得了良好的效果,并在以后的许多领域得到了应用。在本文中,我们将讨论UNet工作的原因和方式

UNet背后的直觉

卷积神经网络(CNN)背后的主要思想是学习图像的特征映射,并利用它进行更细致的特征映射。这在分类问题中很有效,因为图像被转换成一个向量,这个向量用于进一步的分类。但是在图像分割中,我们不仅需要将feature map转换成一个向量,还需要从这个向量重建图像。这是一项巨大的任务,因为要将向量转换成图像比反过来更困难。UNet的整个理念都围绕着这个问题。

在将图像转换为向量的过程中,我们已经学习了图像的特征映射,为什么不使用相同的映射将其再次转换为图像呢?这就是UNet背后的秘诀。用同样的 feature maps,将其用于contraction 来将矢量扩展成segmented image。这将保持图像的结构完整性,这将极大地减少失真。让我们更简单地理解架构。

UNet架构

UNet解释及Python实现

UNet架构

该架构看起来像一个'U'。该体系结构由三部分组成:contraction,bottleneck和expansion 部分。contraction部分由许多contraction块组成。每个块接受一个输入,应用两个3X3的卷积层,然后是一个2X2的最大池化。在每个块之后,核或特征映射的数量会加倍,这样体系结构就可以有效地学习复杂的结构。最底层介于contraction层和expansion 层之间。它使用两个3X3 CNN层,然后是2X2 up convolution层。

这种架构的核心在于expansion 部分。与contraction层类似,它也包含几个expansion 块。每个块将输入传递到两个3X3 CNN层,然后是2X2上采样层。此外,卷积层使用的每个块的feature map数量得到一半,以保持对称性。每次输入也被相应的收缩层的 feature maps所附加。这个动作将确保在contracting 图像时学习到的特征将被用于重建图像。expansion 块的数量与contraction块的数量相同。之后,生成的映射通过另一个3X3 CNN层,feature map的数量等于所需的segment的数量。

UNet中的损失计算

UNet对每个像素使用了一种新颖的损失加权方案,使得分割对象的边缘具有更高的权重。这种损失加权方案帮助U-Net模型以不连续的方式分割生物医学图像中的细胞,以便在binary segmentation map中容易识别单个细胞。

首先,在所得图像上应用pixel-wise softmax,然后是交叉熵损失函数。所以我们将每个像素分类为一个类。我们的想法是,即使在分割中,每个像素都必须存在于某个类别中,我们只需要确保它们可以。因此,我们只是将分段问题转换为多类分类问题,与传统的损失函数相比,它表现得非常好。

UNet实现的Python代码

Python代码如下:

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
class UNet(nn.Module):
 def contracting_block(self, in_channels, out_channels, kernel_size=3):
 block = torch.nn.Sequential(
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=out_channels),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(out_channels),
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=out_channels, out_channels=out_channels),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(out_channels),
 )
 return block
 
 def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
 block = torch.nn.Sequential(
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(mid_channel),
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(mid_channel),
 torch.nn.ConvTranspose2d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)
 )
 return block
 
 def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
 block = torch.nn.Sequential(
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(mid_channel),
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(mid_channel),
 torch.nn.Conv2d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(out_channels),
 )
 return block
 
 def __init__(self, in_channel, out_channel):
 super(UNet, self).__init__()
 #Encode
 self.conv_encode1 = self.contracting_block(in_channels=in_channel, out_channels=64)
 self.conv_maxpool1 = torch.nn.MaxPool2d(kernel_size=2)
 self.conv_encode2 = self.contracting_block(64, 128)
 self.conv_maxpool2 = torch.nn.MaxPool2d(kernel_size=2)
 self.conv_encode3 = self.contracting_block(128, 256)
 self.conv_maxpool3 = torch.nn.MaxPool2d(kernel_size=2)
 # Bottleneck
 self.bottleneck = torch.nn.Sequential(
 torch.nn.Conv2d(kernel_size=3, in_channels=256, out_channels=512),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(512),
 torch.nn.Conv2d(kernel_size=3, in_channels=512, out_channels=512),
 torch.nn.ReLU(),
 torch.nn.BatchNorm2d(512),
 torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, output_padding=1)
 )
 # Decode
 self.conv_decode3 = self.expansive_block(512, 256, 128)
 self.conv_decode2 = self.expansive_block(256, 128, 64)
 self.final_layer = self.final_block(128, 64, out_channel)
 
 def crop_and_concat(self, upsampled, bypass, crop=False):
 if crop:
 c = (bypass.size()[2] - upsampled.size()[2]) // 2
 bypass = F.pad(bypass, (-c, -c, -c, -c))
 return torch.cat((upsampled, bypass), 1)
 
 def forward(self, x):
 # Encode
 encode_block1 = self.conv_encode1(x)
 encode_pool1 = self.conv_maxpool1(encode_block1)
 encode_block2 = self.conv_encode2(encode_pool1)
 encode_pool2 = self.conv_maxpool2(encode_block2)
 encode_block3 = self.conv_encode3(encode_pool2)
 encode_pool3 = self.conv_maxpool3(encode_block3)
 # Bottleneck
 bottleneck1 = self.bottleneck(encode_pool3)
 # Decode
 decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=True)
 cat_layer2 = self.conv_decode3(decode_block3)
 decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=True)
 cat_layer1 = self.conv_decode2(decode_block2)
 decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=True)
 final_layer = self.final_layer(decode_block1)
 return final_layer

UNet解释及Python实现

UNet解释及Python实现

以上Python代码中的UNet模块代表了UNet的整体架构。使用contracaction_block和expansive_block分别创建contraction部分和expansion部分。crop_and_concat函数的作用是将contraction层的输出添加到新的expansion层输入中。训练部分的Python代码可以写成

unet = Unet(in_channel=1,out_channel=2)
#out_channel represents number of segments desired
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
optimizer.zero_grad() 
outputs = unet(inputs)
# permute such that number of desired segments would be on 4th dimension
outputs = outputs.permute(0, 2, 3, 1)
m = outputs.shape[0]
# Resizing the outputs and label to caculate pixel wise softmax loss
outputs = outputs.resize(m*width_out*height_out, 2)
labels = labels.resize(m*width_out*height_out)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

UNet解释及Python实现

结论

图像分割是一个重要的问题,每天都有一些新的研究论文发表。UNet在这类研究中做出了重大贡献。许多新架构的灵感都来自UNet。在业界,这种体系结构有很多变体,因此有必要理解第一个变体,以便更好地理解它们。

相关推荐