使用ONNX将机器学习模型从PyTorch传输到Caffe2和Mobile

使用ONNX将机器学习模型从PyTorch传输到Caffe2和Mobile

在本教程中,我们将介绍如何将PyTorch中的机器学习模型转换为ONNX格式,然后将其加载到Caffe2。然后我们将使用Caffe2’s mobile exporter在移动设备上执行它。

什么是Caffe2和ONNX?

Caffe2(快速特征嵌入的卷积体系结构)是一个可扩展的模块化深度学习框架,采用原始的Caffe框架设计。ONNX(Open Neural Network Exchange)是深度学习模型的一种格式,允许不同的开源AI框架之间的互操作性。ONNX支持Caffe2,PyTorch,MXNet和Microsoft CNTK深度学习框架。

对于本教程,需要安装install onnx,onnx-caffe2和Caffe2。可以使用以下命令通过conda安装onnx和onnx-caffe2:

conda install -c ezyang onnx onnx-caffe2

首先,我们需要导入几个包:

  • io用于处理不同类型的输入和输出。
  • numpy科学计算。
  • nn用于初始化神经网络。
  • torch.utils.model_zoo,它将在给定的URL加载Torch序列化对象。
  • torch.onnx包含以ONNX格式导出模型的函数。

Python代码如下:

import io
import numpy as np
 
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

在PyTorch中创建一个超分辨率模型

超分辨率是一种提高图像和视频分辨率的方法。主要用于图像和视频处理。我们将基于PyTorch文档中的官方示例创建一个超分辨率模型。Python代码如下:

import torch.nn as nn
import torch.nn.init as init
 
 
class SuperResolutionNet(nn.Module):
 def __init__(self, upscale_factor, inplace=False):
 super(SuperResolutionNet, self).__init__()
 
 self.relu = nn.ReLU(inplace=inplace)
 self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
 self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
 self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
 self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
 self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
 
 self._initialize_weights()
 
 def forward(self, x):
 x = self.relu(self.conv1(x))
 x = self.relu(self.conv2(x))
 x = self.relu(self.conv3(x))
 x = self.pixel_shuffle(self.conv4(x))
 return x
 
 def _initialize_weights(self):
 init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
 init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
 init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
 init.orthogonal_(self.conv4.weight)
 
torch_model = SuperResolutionNet(upscale_factor=3)

我们不会训练这个机器学习模型,而是为此目的下载预训练过的权重。加载模型后,我们设置随机批量大小,然后使用预先训练的权重初始化模型。Python代码如下:

model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'
batch_size = 1 
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
 map_location = None
torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))
 
torch_model.train(False)

在PyTorch中导出模型

在PyTorch中导出机器学习模型是通过tracing完成的。这是借助该torch.onnx._export() 函数完成的。此函数将执行模型并记录用于计算输出的运算符的跟踪。由于_export运行模型,我们需要提供输入张量x。Python实现如下:

x = torch.randn(batch_size, 1, 224, 224, requires_grad=True)
 
torch_out = torch.onnx._export(torch_model, 
 x, 
 "super_resolution.onnx", 
export_params=True)

torch_out包含了我们将用于确认导出的机器学习模型的输出,当在coff2中运行时,该输出会计算相同的值。

在Caffe2中使用ONNX表示

这是我们确认的Caffe2和PyTorch网络计算相同的值。这涉及以下几个步骤:

  • importing onnx和onnx_caffe2.backend。
  • 加载ONNX ModelProto object。
  • 准备Caffe2后端以执行模型,该模型将ONNX模型转换为可以执行它的Caffe2 NetDef。
  • 在Caffe2中运行模型。
  • 构造从输入名称到Tensor数据的映射。
  • 运行Caffe2 net并验证数值正确性。

Python代码如下:

import onnx
import onnx_caffe2.backend
 
model = onnx.load("super_resolution.onnx")
 
prepared_backend = onnx_caffe2.backend.prepare(model)
 
W = {model.graph.input[0].name: x.data.numpy()}
 
 
c2_out = prepared_backend.run(W)[0]
 
np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)
 
print("Exported model executed on Caffe2 backend, result looks good")

在移动设备上运行模型

既然模型在Caffe2中,我们可以将其转换为适合在移动设备上运行的格式。这可以使用Caffe2的mobile_exporter来实现。我们生成两个模型原型; 一个用于初始化具有正确权重的模型,另一个用于运行并执行模型。这个过程有几个步骤:

从内部表示中提取workspace 和模型原型。

  • 导入 Caffe2 mobile exporter
  • 调用Export 来获得predict_net,init_net,这两者都是在mobile上运行模型所需的。
  • 将init_net和predict_net保存到文件中,我们将使用该文件在mobile上运行它们。

init_net 具有模型参数和模型输入,同时predict_net 将指导init_net运行时的执行。

Python代码如下:

c2_workspace = prepared_backend.workspace
c2_model = prepared_backend.predict_net
 
from caffe2.python.predictor import mobile_exporter
 
init_net, predict_net = mobile_exporter.Export(c2_workspace, c2_model, c2_model.external_input)
 
with open('init_net.pb', "wb") as fopen:
 fopen.write(init_net.SerializeToString())
with open('predict_net.pb', "wb") as fopen:
fopen.write(predict_net.SerializeToString())

我们使用cat图像运行在coff2中生成的init_net和predict_net,以验证在两次运行中输出(高分辨率cat图像)是相同的。我们从一些标准导入开始:

from caffe2.proto import caffe2_pb2
from caffe2.python import core, net_drawer, net_printer, visualize, workspace, utils
 
import numpy as np
import os
import subprocess
from PIL import Image
from matplotlib import pyplot
from skimage import io, transform

然后我们使用Python的Skimage处理猫图像,就像在神经网络中进行数据处理时一样。加载图像后,我们将其大小调整为224x224尺寸并保存调整后的图像。

img_in = io.imread("catimage.jpg")
 
img = transform.resize(img_in, [224, 224])
 
io.imsave("cat_224x224.jpg", img)

下一步是获取调整大小的cat图像并在Caffe2后端运行超分辨率模型并保存输出图像。以下步骤涉及:

  • 加载已调整大小的图像并将其转换为Ybr格式。
  • 运行我们生成的移动网络,以便正确初始化Caffe2 workspace 。
  • 使用net_printer检查网络的外观,并标识输入和输出blob名称。

Python代码如下:

img = Image.open("cat_224x224.jpg")
img_ycbcr = img.convert('YCbCr')
img_y, img_cb, img_cr = img_ycbcr.split()
 
workspace.RunNetOnce(init_net)
workspace.RunNetOnce(predict_net)
 
print(net_printer.to_string(predict_net))

接下来,我们传入调整大小的cat图像以供机器学习模型处理,然后运行predict_net以获取模型输出。

workspace.FeedBlob("9", np.array(img_y)[np.newaxis, np.newaxis, :, :].astype(np.float32))
 
workspace.RunNetOnce(predict_net)
 
img_out = workspace.FetchBlob("Insert number that was printed above")

接下来,我们构建最终图像并保存。

img_out_y = Image.fromarray(np.uint8((img_out[0, 0]).clip(0, 255)), mode='L')
 
final_img = Image.merge(
 "YCbCr", [
 img_out_y,
 img_cb.resize(img_out_y.size, Image.BICUBIC),
 img_cr.resize(img_out_y.size, Image.BICUBIC),
 ]).convert("RGB")
 
final_img.save("cat_.jpg")

现在让我们在移动设备上执行模型并获取模型输出。执行此操作涉及以下步骤:

  1. 指定将用于在移动设备上执行模型的二进制文件,并导出稍后要检索的模型输出
  2. Pushing 我们之前保存的二进制文件和init_net和proto_net
  3. 将输入图像blob序列化为blob proto,然后将其发送到mobile以供执行
  4. 将输入图像blob推送到adb
  5. 在mobile上运行网络
  6. 从adb获取模型输出并保存到文件
  7. 使用之前的相同步骤恢复模型的输出内容和后处理
  8. 保存图像

Python代码如下:

CAFFE2_MOBILE_BINARY = ('specifiedbinary')
 
os.system('adb push ' + CAFFE2_MOBILE_BINARY + ' /data/local/tmp/')
os.system('adb push init_net.pb /data/local/tmp')
os.system('adb push predict_net.pb /data/local/tmp')
 
with open("input.blobproto", "wb") as fid:
 fid.write(workspace.SerializeBlob("9"))
 
os.system('adb push input.blobproto /data/local/tmp/')
 
os.system(
 'adb shell /data/local/tmp/specifiedbinary ' 
 '--init_net=/data/local/tmp/super_resolution_mobile_init.pb ' 
 '--net=/data/local/tmp/super_resolution_mobile_predict.pb ' 
 '--input=9 ' 
 '--input_file=/data/local/tmp/input.blobproto ' 
 '--output_folder=/data/local/tmp ' 
 '--output=27,9 ' 
 '--iter=1 ' 
 '--caffe2_log_level=0 '
)
 
 
os.system('adb pull /data/local/tmp/27 ./output.blobproto')
 
 
blob_proto = caffe2_pb2.BlobProto()
blob_proto.ParseFromString(open('./output.blobproto').read())
img_out = utils.Caffe2TensorToNumpyArray(blob_proto.tensor)
img_out_y = Image.fromarray(np.uint8((img_out[0,0]).clip(0, 255)), mode='L')
final_img = Image.merge(
 "YCbCr", [
 img_out_y,
 img_cb.resize(img_out_y.size, Image.BICUBIC),
 img_cr.resize(img_out_y.size, Image.BICUBIC),
 ]).convert("RGB")
final_img.save("cat_mobile.jpg")

结论

您可以比较来自Caffe2 执行的cat_.jpg和来自mobile执行的cat_mobile.jpg。如果这两张图片看起来不一样,说明在mobile 执行过程中出现了问题。

相关推荐