利用残差神经网络对疟疾细胞图像进行分类
本文描述了构建一个图像分类器的过程和经验教训,该图像分类器能够自动对人类血细胞图像进行分类,从而判断其是否感染了疟疾——因此是一个二元分类。
数据集
该项目的灵感来自美国国家医学图书馆的研发部门Lister Hill国家生物医学通信中心的通信工程分部(CEB)的项目。提供的机器学习数据集(https://ceb.nlm.nih.gov/repositories/malaria-datasets/)是平衡的 - 包括总共27,558张细胞图像,其中被寄生(感染)和未感染(清洁)细胞的实例是相等的。对于模型训练/验证,机器学习数据集以80/20的比例拆分。
#from fastai.data_block import * from fastai.vision import * import pandas as pd # Download and unzip the dataset from # 'https://ceb.nlm.nih.gov/proj/malaria/cell_images.zip' # to PATH PATH = 'data/' DATAPATH = f'{PATH}/cell_images/' files = get_files(f'{DATAPATH}', extensions='.png', recurse=True) # Get label from file_path -- folder's name def get_label(file_path): return 'infected' if '/Parasitized/' in str(file_path) else 'clean' bs=64 # Batch size data = ImageDataBunch.from_name_func(f'{DATAPATH}', fnames=files, label_func=get_label, # Parasitized -> infected; Uninfected -> clean bs = bs, ds_tfms=get_transforms(), size=170 # resize all images ).normalize(imagenet_stats) df = pd.DataFrame(data.y.items) df['category'] = df[0].replace({0:data.classes[0], 1:data.classes[1]}) data
df['category'].value_counts()
clean 11032
infected 11015
Name: category, dtype: int64
# infected cells ratio in the dataset or sample data.y.items.sum()/len(data.y.items)
0.4996144600172359
data.show_batch(rows=3, figsize=(7,6))
模型1 - ResNet-34
Stage 1:我们将获取预训练的机器学习模型,并在我们的数据上训练它的最后一层。我们将准确度用作指标。
learn = create_cnn(data, models.resnet34, pretrained=True, metrics=accuracy) learn.fit_one_cycle(8)
Stage 1的最佳准确度= 0.964979。
结果分析:
interp = ClassificationInterpretation.from_learner(learn) losses,idxs = interp.top_losses() #len(data.valid_ds)==len(losses)==len(idxs) interp.plot_top_losses(9, figsize=(10,10))
以下是混淆矩阵的样子:
doc(interp.plot_top_losses) interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)
interp.most_confused(min_val=2)
[('infected', 'clean', 105), ('clean', 'infected', 72)]
Stage 2.现在我们将“Unfreezing”所有层,选择学习率并再次训练整个机器学习模型。
learn.unfreeze() # Enable all layers of NN to learn -- set requires_grad = True #learn.load('stage-1'); learn.lr_find(); learn.recorder.plot()
为了确保最有效的训练,我们选择了一个合适的学习率——大约比上升点低一度。
我们在这里使用的学习率将在从1e-6到1e-5之间。
learn.unfreeze() learn.fit_one_cycle(4, max_lr=slice(1e-6,1e-5))
此时达到的最终精度= 0.966975。然而,它并不是最好的:第3 epoch的准确性更好。尽管如此,ResNet-34模型在Stage 2的准确性略有提高。
interp = ClassificationInterpretation.from_learner(learn) losses,idxs = interp.top_losses() #len(data.valid_ds)==len(losses)==len(idxs) interp.plot_top_losses(9, figsize=(10,10))
以下是混淆矩阵的样子:
doc(interp.plot_top_losses) interp.plot_confusion_matrix(figsize=(5,5))
interp.most_confused(min_val=2)
[('infected', 'clean', 101), ('clean', 'infected', 81)]
这意味着:
- 2764(3.65%)感染细胞中的101个被归类为清洁 - 假阴性 ;
- 2747(2.95%)清洁细胞中的81个被归类为感染 - 假阳性。
模型2 - ResNet-50
Stage 1 只训练最后一层
我们将使用较小的batch size。
data = ImageDataBunch.from_name_func(f'{DATAPATH}', fnames=files, label_func=get_label, bs=bs//2, ds_tfms=get_transforms(), size=170 ).normalize(imagenet_stats) learn = create_cnn(data, models.resnet50, pretrained=True, # Leave only the last layer with requires_grad = True metrics=accuracy) learn.lr_find() learn.recorder.plot()
然后选取学习率对机器学习模型进行8个epochs的训练:
learn.fit_one_cycle(8, max_lr=1e-2)
ResNet-50在Stage 1 取得的最终结果 - 准确度= 0.967882
一些错误分类的图像:
interp = ClassificationInterpretation.from_learner(learn) losses,idxs = interp.top_losses() #len(data.valid_ds)==len(losses)==len(idxs) interp.plot_top_losses(9, figsize=(10,10))
以下是混淆矩阵的样子:
doc(interp.plot_top_losses) interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)
interp.most_confused(min_val=2)
[('infected', 'clean', 105), ('clean', 'infected', 72)]
Stage 2
我们现在将Unfreezing ResNet-50模型的所有层,并使用手动选择的学习率再次训练它。
learn.lr_find() learn.recorder.plot()
learn.unfreeze() # Enable all layers of NN to learn -- set requires_grad = True learn.fit_one_cycle(4, max_lr=slice(2e-5,1e-4))
ResNet-50在Stage 2 取得的最佳结果 - 准确度= 0.966975。
一些错误分类的图像:
interp = ClassificationInterpretation.from_learner(learn) losses,idxs = interp.top_losses() #len(data.valid_ds)==len(losses)==len(idxs) interp.plot_top_losses(9, figsize=(10,10))
以下是混淆矩阵的样子:
doc(interp.plot_top_losses) interp.plot_confusion_matrix(figsize=(5,5)) #, dpi=60)
interp.most_confused(min_val=2)
[('infected', 'clean', 103), ('clean', 'infected', 79)]
这意味着:
- 2708例(3.8%)感染细胞中有103例被归类为清洁 - 假阴性 ;
- 2803个中有79个(2.82%)清洁细胞被归类为感染 - 假阳性。
结论
- 看下ResNet-50 at Stage 1错误分类的图像,我们可能得出结论,实际上某些图像可能在数据集中被错误地标记 - 明显感染的图像被标记为清洁,反之亦然。
- 两种机器学习模型的训练准确度相当:
- 但是,Stage 1的ResNet-50更准确。因此,我们可以加载在Stage 1中保存的模型并按原样使用或继续使用不同的超参数(例如另一个学习率)再次训练它以获得更好的结果。
- 每个epoch的训练时间取决于机器学习模型复杂性 - 层数和训练层数(Stage 1的最后一层与Stage 2的所有层)。此外,它可能与批量大小相关 - 较小的batch size(32 vs. 64)用于ResNet-50。