Simple deep learning image classification with PyTorch and the Fashion MNIST dataset

Overview

This article provides reference code for anyone who wants to build a simple deep learning image classification network with PyTorch and Fashion MNIST.

We walk through every working part of such a network: loading the data, defining the network, optimizing the weights on a GPU, and evaluating performance.

About the Fashion MNIST dataset

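Fashion MNIST is a dataset of 70,000 grayscale images in 10 classes. Each image is 28×28 pixels; 60,000 images form the training set and 10,000 form the test set.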
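
The ten integer class labels correspond to clothing categories. The lookup below is a minimal sketch, assuming the label order published with the dataset (torchvision exposes the same names through the `classes` attribute of the dataset object). The helper name `label_name` is hypothetical and only used for illustration.

```python
# Label order as published with Fashion MNIST (index -> category name)
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(index: int) -> str:
    """Return the category name for an integer class label (hypothetical helper)."""
    return FASHION_MNIST_CLASSES[index]


print(label_name(9))
```

A lookup like this makes the integer labels printed during evaluation easier to read.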

1. Check whether a GPU is available

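import torch

# True when PyTorch can see a CUDA-capable GPU
print(torch.cuda.is_available())
# Use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)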

2. Download and load the Fashion MNIST dataset

import torch
from torchvision import datasets, transforms

# Instructions from here:
# https://www.kaggle.com/ishvindersethi22/fashion-mnist-using-pytorch/data
# Define a transform to normalize the data: ToTensor scales pixels to [0, 1],
# then Normalize computes (x - 0.) / 0.5, giving values in [0, 2]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.], [0.5])])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
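
To get a feel for how many batches each loader yields per epoch, the count can be worked out by hand. This is a plain-Python sketch, not part of the original code; it assumes the usual 60,000/10,000 train/test split and the `DataLoader` default of keeping the last partial batch (`drop_last=False`). The function name `num_batches` is hypothetical.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a loader would yield over one pass of the data."""
    if drop_last:
        # The incomplete final batch is discarded
        return num_samples // batch_size
    # The incomplete final batch is kept
    return math.ceil(num_samples / batch_size)


print(num_batches(60_000, 64))  # training batches per epoch
print(num_batches(10_000, 64))  # test batches
```

Under these assumptions the training loader yields 938 batches and the test loader 157, which is the value `len(trainloader)` or `len(testloader)` would report.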

3. Display a sample image

import matplotlib.pyplot as plt

def imshow(img):
    img = img / 2 + 0.   # unnormalize: invert Normalize([0.], [0.5])
    img = img.squeeze()  # drop the channel dimension: (1, 28, 28) -> (28, 28)
    plt.imshow(img, cmap='gray')

images, labels = next(iter(trainloader))
imshow(images[0, :])
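
The `img / 2 + 0.` step in `imshow` undoes the `Normalize([0.], [0.5])` transform applied at load time. The round trip can be checked with ordinary floats. This is a minimal sketch of the arithmetic only; the function names are hypothetical and no PyTorch is involved.

```python
def normalize(x: float, mean: float = 0.0, std: float = 0.5) -> float:
    """Same arithmetic as a per-channel normalize: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y: float, mean: float = 0.0, std: float = 0.5) -> float:
    """Inverse of normalize: y * std + mean (here: y / 2 + 0.)."""
    return y * std + mean


for pixel in (0.0, 0.25, 1.0):
    y = normalize(pixel)
    print(pixel, "->", y, "->", unnormalize(y))
```

With a mean of 0 and a standard deviation of 0.5, pixel values in [0, 1] are stretched to [0, 2], and dividing by 2 brings them back.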

4. Define and train the network

from collections import OrderedDict
from torch import optim, nn

hidden_units = [4, 8, 16]
output_units = 10

class Flatten(nn.Module):
    """Reshape (N, C, H, W) to (N, C*H*W)."""
    def forward(self, input):
        return input.view(input.size(0), -1)

# Three stride-2 convolutions shrink the 28x28 input to 14x14, 7x7 and 4x4;
# the final 4x4 convolution reduces it to 1x1 with one channel per class
model_d = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, hidden_units[0], 3, stride=2, padding=1)),
    ('Relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(hidden_units[0], hidden_units[1], 3, stride=2, padding=1)),
    ('Relu2', nn.ReLU()),
    ('conv3', nn.Conv2d(hidden_units[1], hidden_units[2], 3, stride=2, padding=1)),
    ('Relu3', nn.ReLU()),
    ('conv4', nn.Conv2d(hidden_units[2], output_units, 4, stride=4, padding=0)),
    ('log_softmax', nn.LogSoftmax(dim=1))
]))
model_d.to(device)
optimizer_d = optim.Adam(model_d.parameters(), lr=0.01)
criterion = nn.NLLLoss()  # expects log-probabilities, hence the LogSoftmax layer

epochs = 10
for i in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer_d.zero_grad()

        # Run classification model; the output has shape (N, 10, 1, 1)
        predicted_labels = model_d(images)
        classification_loss = criterion(Flatten()(predicted_labels), labels)

        # Optimize classification weights
        classification_loss.backward()
        optimizer_d.step()

        running_loss += classification_loss.item()
    # Report the mean loss per batch for this epoch
    print(f"{i} Training loss: {running_loss/len(trainloader)}")
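
The network only works because the four convolutions reduce a 28×28 image to a single 1×1 position with 10 channels, which `Flatten` then turns into one score per class. The sketch below checks that shape arithmetic in plain Python, assuming the standard output-size formula for a convolution without dilation, floor((n + 2p - k) / s) + 1. The function name `conv_out_size` is hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) for conv1..conv4 as defined in the model
layers = [(3, 2, 1), (3, 2, 1), (3, 2, 1), (4, 4, 0)]

size = 28  # Fashion MNIST images are 28x28
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out_size(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)  # [28, 14, 7, 4, 1]
```

If the input resolution or any kernel, stride, or padding value changes, rerunning this check shows whether the final layer still lands on a 1×1 output.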

5. Evaluate the network

total_correct = 0
total_num = 0
with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        # Class probabilities, shape (N, 10)
        ps = Flatten()(torch.exp(model_d(images)))
        # Top-1 class indices, shape (1, N)
        predictions = ps.topk(1, 1, True, True)[1].t()
        correct = predictions.eq(labels.view(1, -1))

        total_correct += correct.sum().item()
        total_num += images.shape[0]

print('Accuracy:', total_correct / float(total_num))

# Inspect one sample from the last test batch
index = 0
print('Correct Label:', labels[index].item())
print('Predicted Label:', predictions[0, index].item())
imshow(images[index, :].cpu())
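
The evaluation boils down to taking the highest-scoring class for each image and counting matches with the labels. The plain-Python sketch below mirrors that logic on a few made-up logits (illustrative numbers, not outputs of the trained model). It also shows that the top-1 class is the same whether it is picked from log-probabilities or from their exponentials, since `exp` is monotonic. All function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of scores."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]


def top1(scores):
    """Index of the highest score."""
    return max(range(len(scores)), key=lambda i: scores[i])


def accuracy(batch_scores, labels):
    """Fraction of rows whose top-1 index equals the label."""
    correct = sum(top1(s) == y for s, y in zip(batch_scores, labels))
    return correct / len(labels)


# Made-up logits for three samples and three classes (illustrative only)
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
labels = [0, 2, 0]

log_probs = [log_softmax(row) for row in logits]
probs = [[math.exp(v) for v in row] for row in log_probs]

print(accuracy(log_probs, labels))
print(accuracy(probs, labels))
```

Both calls give the same accuracy, so the exponential matters only when actual probabilities are wanted, not for picking the predicted class.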

Summary

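This article demonstrated all the working parts of a deep learning image classification network, using one of the most basic AI applications as a simple learning exercise.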

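  1. We loaded the Fashion MNIST dataset
  2. We defined a simple deep convolutional network
  3. We optimized the network weights with the Adam optimizer on a GPU
  4. We evaluated the network and reached about 85% accuracy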
