Tensorflow使用带有Python示例的数据集和迭代器的TFRecord
在本文中,我们将探索什么是TFRecord,如何在机器学习数据集中使用它,以及如何用迭代器提取数据。我们将讨论一个非常重要但记录较少的话题,如何保存图像在TFRecord。我们还将研究TFRecord规模膨胀的常见问题。
什么是TFRecord ?
TFRecord是一个单独的聚合压缩文件,用于汇总机器学习模型训练/测试期间所需的所有数据(以任何格式存在)。该特定文件可以跨多个系统传输,也独立于将要训练它的机器学习模型。TFRecord文件还可能包含重建原始数据所需的额外开销数据,如果我们在没有TFRecord的情况下进行训练,则可能不需要这些数据。此外,如果数据集非常大,我们可能必须创建多个相似类型的TFRecord文件。
如何建立TFRecord?
TFRecord中的任何数据都必须以bytes列表或float列表或int64列表的形式存储。创建的每个数据列表实体都必须由一个Feature类包装。接下来,每个特性都存储在一个键值对中,键对应分配给每个特性的标题。这些标题将在以后从TFRecord提取数据时使用。创建的字典作为输入传递给Features类。最后,将features对象作为输入传递给示例类。然后这个示例类对象被追加到TFRecord中。对于必须存储在TFRecord中的每种类型的数据,重复上述过程。接下来给出了使用简单数据创建TFRecord的Python代码。
import tensorflow as tf
data_arr = [
{
'int_data': 108,
'float_data': 2.45,
'str_data': 'String 100',
'float_list_data': [256.78, 13.9]
},
{
'int_data': 37,
'float_data': 84.3,
'str_data': 'String 200',
'float_list_data': [1.34, 843.9, 65.22]
}
]
def get_example_object(data_record):
# Convert individual data into a list of int64 or float or bytes
int_list1 = tf.train.Int64List(value = [data_record['int_data']])
float_list1 = tf.train.FloatList(value = [data_record['float_data']])
# Convert string data into list of bytes
str_list1 = tf.train.BytesList(value = [data_record['str_data'].encode('utf-8')])
float_list2 = tf.train.FloatList(value = data_record['float_list_data'])
# Create a dictionary with above lists individually wrapped in Feature
feature_key_value_pair = {
'int_list1': tf.train.Feature(int64_list = int_list1),
'float_list1': tf.train.Feature(float_list = float_list1),
'str_list1': tf.train.Feature(bytes_list = str_list1),
'float_list2': tf.train.Feature(float_list = float_list2)
}
# Create Features object with above feature dictionary
features = tf.train.Features(feature = feature_key_value_pair)
# Create Example object with features
example = tf.train.Example(features = features)
return example
with tf.python_io.TFRecordWriter('example.tfrecord') as tfwriter:
# Iterate through all records
for data_record in data_arr:
example = get_example_object(data_record)
# Append each example into tfrecord
tfwriter.write(example.SerializeToString())
为图像创建TFRecord
现在我们已经基本了解了如何为包含字典和列表的文本类型的数据创建TFRecord,让我们继续添加图像。我们的toy dataset(https://github.com/Prasad9/TFRecord_Images/tree/master/Images)包括10个图像和两种类型,即猫和狗。数据集是PNG和JPEG类型图像的混合。使用的图像如下所示:
解决此问题的一种常见方法是将这些图像的Numpy表示转换为字符串并将其存储到TFRecord中。随着数据表示的格式发生变化,我们必须存储图像形状的开销数据。我们来看看Python实现:
import tensorflow as tf
import os
import matplotlib.image as mpimg
class GenerateTFRecord:
def __init__(self, labels):
self.labels = labels
def convert_image_folder(self, img_folder, tfrecord_file_name):
# Get all file names of images present in folder
img_paths = os.listdir(img_folder)
img_paths = [os.path.abspath(os.path.join(img_folder, i)) for i in img_paths]
with tf.python_io.TFRecordWriter(tfrecord_file_name) as writer:
for img_path in img_paths:
example = self._convert_image(img_path)
writer.write(example.SerializeToString())
def _convert_image(self, img_path):
label = self._get_label_with_filename(img_path)
image_data = mpimg.imread(img_path)
# Convert image to string data
image_str = image_data.tostring()
# Store shape of image for reconstruction purposes
img_shape = image_data.shape
# Get filename
filename = os.path.basename(img_path)
example = tf.train.Example(features = tf.train.Features(feature = {
'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[2]])),
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_str])),
'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
}))
return example
def _get_label_with_filename(self, filename):
basename = os.path.basename(filename).split('.')[0]
basename = basename.split('_')[0]
return self.labels[basename]
if __name__ == '__main__':
labels = {'cat': 0, 'dog': 1}
t = GenerateTFRecord(labels)
t.convert_image_folder('Images', 'images.tfrecord')
我们生成了一个名为images.tfrecord的文件。这个文件的大小是惊人的20.3 MB,而如果你把数据集中的单个图像文件的大小加起来,只有1.15 MB,这可能是很多人不喜欢TFRecord和许多开发人员开始停止使用TFRecord的主要原因之一。
为什么TFRecord在内存中变得如此巨大?
我们必须分析这个问题,开始研究每个图像的shape ,例如dog_2.jpg图片。该图像的形状是(1414,943,3)。将这些维度中的每一个相乘,即1414 x 943 x 3等于4000206.因此,在Numpy表示内(假设数据类型为uint8),图像由总共4000206个整数表示。当我们调用tostring()的Numpy方法时,接下来将这些4000206数字顺序存储在二进制字符串中。
为了给出字符串长度的一个小例子(虽然技术上不正确),假设每个uint8数都大于100.所以当这个数字转换为字符串时,它总共代表三个字符。然后每个数字必须用分隔符分隔,假设正在使用(,)。在我们的Numpy示例中总共有4000206个数字。所以我们预期的字符串长度是:
4000206 x(3个字符+ 1个分隔符字符)= 16000824个字符。
如果一个字符是一个字节,则转换为15.25 MB。
如何克服TFRecord大小问题?
让我们看一下图像的另一个属性:图像的存储大小。训练中使用的许多图像的存储大小通常很小,无论图像shape如何,存储大小的变化从很少的KBs到100s的KB不等。因此,让我们将图像的字节直接存储到TFRecord中。Tensorflow为我们提供了tf.gfile.FastGFile类,它可以以字节格式读取图像。让我们看一下修改后的Python代码:
class GenerateTFRecord:
def _convert_image(self, img_path):
label = self._get_label_with_filename(img_path)
img_shape = mpimg.imread(img_path).shape
filename = os.path.basename(img_path)
# Read image data in terms of bytes
with tf.gfile.FastGFile(img_path, 'rb') as fid:
image_data = fid.read()
example = tf.train.Example(features = tf.train.Features(feature = {
'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[2]])),
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),
}))
return example
现在我们的images.tfrecord文件的大小是1.2 MB,几乎与各个图像的大小相同。
进一步减少TFRecord大小
现在,让我们尝试进一步降低TFRecord的大小。PNG图像倾向于使用更清晰的边缘细节捕获更多信息。这是以增加图像存储大小为代价的。转换为JPEG图像将无限制地模糊您的图像,但它会奖励您可测量的存储大小减少量。Tensorflow还为您提供转换时希望保留的质量。让我们看看Python实现的相关代码部分。
class GenerateTFRecord:
def __init__(self, labels):
self.labels = labels
self._create_graph()
# Create graph to convert PNG image data to JPEG data
def _create_graph(self):
tf.reset_default_graph()
self.png_img_pl = tf.placeholder(tf.string)
png_enc = tf.image.decode_png(self.png_img_pl, channels = 3)
# Set how much quality of image you would like to retain while conversion
self.png_to_jpeg = tf.image.encode_jpeg(png_enc, format = 'rgb', quality = 100)
def _is_png_image(self, filename):
ext = os.path.splitext(filename)[1].lower()
return ext == '.png'
# Run graph to convert PNG image data to JPEG data
def _convert_png_to_jpeg(self, img):
sess = tf.get_default_session()
return sess.run(self.png_to_jpeg, feed_dict = {self.png_img_pl: img})
def _convert_image(self, img_path):
label = self._get_label_with_filename(img_path)
img_shape = mpimg.imread(img_path).shape
filename = os.path.basename(img_path).split('.')[0]
# Read image data in terms of bytes
with tf.gfile.FastGFile(img_path, 'rb') as fid:
image_data = fid.read()
# Encode PNG data to JPEG data
if self._is_png_image(img_path):
image_data = self._convert_png_to_jpeg(image_data)
example = tf.train.Example(features = tf.train.Features(feature = {
'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [filename.encode('utf-8')])),
'rows': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[0]])),
'cols': tf.train.Feature(int64_list = tf.train.Int64List(value = [img_shape[1]])),
'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [3])),
'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image_data])),
'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label])),
}))
return example
如果alpha通道(即RGBA的A)是机器学习模型输入的一部分,或者机器学习模型确实需要图像中更加清晰的细节,请不要执行此步骤。
现在,保持100%的编码质量,我们将早期的TFRecord文件减少了1.2 MB到579.5 KB。但是如果我们将质量略微降低到95%,我们可以将TFRecord文件大小减小到402.4 KB。我在以后的质量变化过程中绘制了对数据集图像的影响,但现在我已经将TFRecord文件大小的变化列表在下面,以获得不同的质量。
我们能进一步减少吗?
是的,我们可以进一步改进。但是这一次,我们不会改进我们的代码,但我们将使用外部图像优化工具。一个好的图像优化工具将减少图像的大小,绝对不会降低图像质量。我总是使用Trimage。通过使用Trimage,我将以下结果制成表格:
在我们开始转换为TFRecord之前,这个特定步骤可能应该用作开始步骤。
从TFRecord中提取数据(简要说明)
现在我们的TFRecords准备就绪,是时候将它们发送到训练管道了。第一步是使用所有TFRecord文件路径初始化TFRecordDataset。之后,我们必须提取TFRecords中存在的各种特征。我们在此步骤的早期指定TFRecord形成期间使用的各种键。如果我们事先知道每个数据记录的bytes列表或float或int64中存在的项目数,我们可以使用FixedLenFeature,否则,我们使用VarLenFeature类。接下来,API parse_single_example提取每个数据记录的字典对象。让我们看一下之前用简单文本字典数据创建的TFRecord的提取过程。Python代码如下:
import tensorflow as tf
def extract_fn(data_record):
features = {
# Extract features using the keys set during creation
'int_list1': tf.FixedLenFeature([], tf.int64),
'float_list1': tf.FixedLenFeature([], tf.float32),
'str_list1': tf.FixedLenFeature([], tf.string),
# If size is different of different records, use VarLenFeature
'float_list2': tf.VarLenFeature(tf.float32)
}
sample = tf.parse_single_example(data_record, features)
return sample
# Initialize all tfrecord paths
dataset = tf.data.TFRecordDataset(['example.tfrecord'])
dataset = dataset.map(extract_fn)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:
data_record = sess.run(next_element)
print(data_record)
except:
pass
从TFRecord中提取图像
我们扩展了提取简单TFRecord文件的相同概念,以便从中提取图像。在tf.image.decode_image API 的帮助下,我们可以解码任何格式的图像。作为预防措施,我们验证解码图像的shape是否与TFRecord中存储的行,列和通道的开销数据相匹配。让我们深入研究TFRecord提取图像的Python实现。
import tensorflow as tf
import os
import shutil
import matplotlib.image as mpimg
import numpy as np
class TFRecordExtractor:
def __init__(self, tfrecord_file):
self.tfrecord_file = os.path.abspath(tfrecord_file)
def _extract_fn(self, tfrecord):
# Extract features using the keys set during creation
features = {
'filename': tf.FixedLenFeature([], tf.string),
'rows': tf.FixedLenFeature([], tf.int64),
'cols': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64),
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64)
}
# Extract the data record
sample = tf.parse_single_example(tfrecord, features)
image = tf.image.decode_image(sample['image'])
img_shape = tf.stack([sample['rows'], sample['cols'], sample['channels']])
label = sample['label']
filename = sample['filename']
return [image, label, filename, img_shape]
def extract_image(self):
# Create folder to store extracted images
folder_path = './ExtractedImages'
shutil.rmtree(folder_path, ignore_errors = True)
os.mkdir(folder_path)
# Pipeline of dataset and iterator
dataset = tf.data.TFRecordDataset([self.tfrecord_file])
dataset = dataset.map(self._extract_fn)
iterator = dataset.make_one_shot_iterator()
next_image_data = iterator.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
# Keep extracting data till TFRecord is exhausted
while True:
image_data = sess.run(next_image_data)
# Check if image shape is same after decoding
if not np.array_equal(image_data[0].shape, image_data[3]):
print('Image {} not decoded properly'.format(image_data[2]))
continue
save_path = os.path.abspath(os.path.join(folder_path, image_data[2].decode('utf-8')))
mpimg.imsave(save_path, image_data[0])
print('Save path = ', save_path, ', Label = ', image_data[1])
except:
pass
if __name__ == '__main__':
t = TFRecordExtractor('./images.tfrecord')
t.extract_image()
结果
最后,让我们根据图像格式和转换次数显示图像的获取结果。
这是JPEG和PNG图像的输出集,没有PNG图像转换为JPEG。
这是转换时PNG图像的输出集,转换质量分别保持在100%,95%和90%。但请注意,输入PNG图像的透明部分已转换为黑色。